
Commit 4acfe6c

PaliC, msaroufim, and bertmaher authored
Productionize the dataset we are using for BackendBench (#93)
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Bert Maher <[email protected]>
1 parent f6b1e32 commit 4acfe6c

File tree

10 files changed: +547 -89 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ backendbench.egg-info/
 CLAUDE.md
 venv/
 ops/
+datasets/
 uv.lock
 pytorch_operator_coverage.csv
 .pre-commit-cache/

BackendBench/data_loaders.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Shared data loading utilities for reading trace and parquet files.
"""

import hashlib
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Union

import pyarrow.parquet as pq

import requests
import torch
from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args
from tqdm import tqdm


def _args_size(args):
    """Calculate the size of arguments in bytes."""

    size = 0
    for arg in args:
        if isinstance(arg, torch.Tensor):
            size += arg.numel() * arg.element_size()
        elif isinstance(arg, (tuple, list)):
            size += _args_size(arg)
    return size


def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List[Dict]:
    """
    Parse a single trace file and return a list of operation dictionaries.

    Args:
        filename: Path to trace file
        filter: Optional list of operation name filters
    """
    op_inputs = []
    op = None

    with open(filename, "r") as f:
        lines = list(f)
        iterator = tqdm(lines, desc=f"Parsing {Path(filename).name}")
        for line in iterator:
            if m := re.match("Operator: (.*)", line):
                op = m.group(1)
                # This remap works around a version skew between the PyTorch
                # version used to develop BackendBench and the one used in
                # tritonbench, where SymInt didn't exist.
                # @todo: see if we can remove this before releasing
                if op == "aten.sum.SymInt":
                    op = "aten.sum.dim_IntList"
            if m := re.match("cnt: \\d+, (.*)", line):
                assert op is not None
                args_str = m.group(1)
                cnt = int(m.group(0).split(",")[0].split(":")[1])

                if filter is None or any(f in op for f in filter):
                    args, kwargs = deserialize_args(args_str)
                    size = _args_size(args) + _args_size(list(kwargs.values()))
                    size = size / (1024 * 1024)  # Convert to MB
                    is_synthetic = cnt == 0

                    op_inputs.append(
                        {
                            "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(),
                            "op_name": op,
                            "args": args_str,
                            "arg_size": size,
                            "count": cnt,
                            "is_synthetic": is_synthetic,
                        }
                    )
    return op_inputs


def _parse_trace_stream(
    stream, filter: Optional[List[str]] = None, desc: str = "Parsing stream"
) -> List[Dict]:
    """
    Parse trace data from a text stream (e.g., from requests.Response.iter_lines()).

    Args:
        stream: Iterable of lines (strings or bytes)
        filter: Optional list of operation name filters
        desc: Description for progress bar
    """
    op_inputs = []
    op = None

    iterator = tqdm(stream, desc=desc)

    for line in iterator:
        # Handle bytes from response stream
        if isinstance(line, bytes):
            line = line.decode("utf-8")

        if m := re.match("Operator: (.*)", line):
            op = m.group(1)
            if op == "aten.sum.SymInt":
                op = "aten.sum.dim_IntList"
        if m := re.match("cnt: \\d+, (.*)", line):
            assert op is not None
            args_str = m.group(1)
            cnt = int(m.group(0).split(",")[0].split(":")[1])

            if filter is None or any(f in op for f in filter):
                args, kwargs = deserialize_args(args_str)
                size = _args_size(args) + _args_size(list(kwargs.values()))
                del args, kwargs
                cleanup_memory_and_gpu()
                size = size / (1024 * 1024)  # Convert to MB
                is_synthetic = cnt == 0

                op_inputs.append(
                    {
                        "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(),
                        "op_name": op,
                        "args": args_str,
                        "arg_size": size,
                        "count": cnt,
                        "is_synthetic": is_synthetic,
                    }
                )
    return op_inputs


def load_ops_from_source(
    source: Union[str, Path],
    format: str = "auto",
    filter: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Load operation data from various sources and formats.

    Args:
        source: File path or URL
        format: "trace", "parquet", or "auto" (detect from file extension)
        filter: Optional list of operation name filters

    Returns:
        List of dictionaries with detailed operation info

    Auto-detection behavior:
        - https://domain.com/data.parquet → parquet format
        - https://domain.com/data.txt → trace format
        - https://domain.com/data → trace format (fallback)
        - local_file.parquet → parquet format
        - local_file.txt → trace format
    """

    # Auto-detect format if not specified
    if format == "auto":
        if isinstance(source, str):
            # Check file extension first (works for both local files and URLs)
            if source.endswith(".parquet"):
                format = "parquet"
            elif source.endswith(".txt"):
                format = "trace"
            elif source.startswith(("http://", "https://")):
                # Remote URL without a recognizable extension - default to trace
                format = "trace"
            else:
                raise ValueError(f"Unsupported source: {source}")
        else:
            raise ValueError(f"Unsupported source: {source}")

    if format == "parquet":
        return _load_from_parquet(source, filter)
    elif format == "trace":
        # Always load full data - consumers can extract what they need
        return _load_from_trace(source, filter)
    else:
        raise ValueError(f"Unsupported format: {format}")


def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]]):
    """Load operations from a parquet file."""
    table = pq.read_table(source)
    df = table.to_pandas()

    # Apply filter if provided
    if filter:
        mask = df["op_name"].apply(lambda op: any(f in op for f in filter))
        df = df[mask]

    return df.to_dict("records")


def op_list_to_benchmark_dict(ops_list: List[Dict]) -> Dict[str, List[str]]:
    """
    Convert a list of operation dictionaries to a dictionary format that can
    be used for benchmarking.

    Args:
        ops_list: List of dicts with 'op_name' and 'args' keys

    Returns:
        Dictionary mapping op_name to list of args strings
    """
    result = {}
    for op_data in ops_list:
        if not op_data["included_in_benchmark"]:
            continue
        op_name = op_data["op_name"]
        args = op_data["args"]
        if op_name not in result:
            result[op_name] = []
        result[op_name].append(args)
    return result


def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]]) -> List[Dict]:
    """Load operations from trace file(s) and return a list of dicts."""
    op_inputs = []

    # Handle URLs - stream directly without saving to disk
    if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")):
        logging.info(f"Downloading trace from {source}")
        with requests.get(source, stream=True) as response:
            response.raise_for_status()
            desc = "Parsing"
            op_inputs = _parse_trace_stream(response.iter_lines(), filter, desc)

    # Handle single files
    else:
        op_inputs = _parse_trace_file(source, filter)

    return op_inputs
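
For context, a minimal usage sketch of the new loader (the trace path is hypothetical, and the "Operator: ..." / "cnt: N, ..." line format is inferred from the parsing regexes above):

from BackendBench.data_loaders import load_ops_from_source, op_list_to_benchmark_dict

# Trace lines look roughly like:
#   Operator: aten.relu.default
#   cnt: 12, <serialized args>
ops = load_ops_from_source("traces/example_trace.txt", format="auto")  # hypothetical path

# op_list_to_benchmark_dict keys off "included_in_benchmark", which the
# filtering pass (apply_skip_ops_filter, below) normally sets; mark everything
# as included here purely for the demo.
for op in ops:
    op["included_in_benchmark"] = True
suite = op_list_to_benchmark_dict(ops)  # {op_name: [args_str, ...]}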

BackendBench/eval.py

Lines changed: 1 addition & 3 deletions
@@ -15,9 +15,7 @@
 except ImportError:
     TRITON_AVAILABLE = False

-
-from BackendBench.utils import uses_cuda_stream
-from BackendBench.utils import serialize_args
+from BackendBench.utils import serialize_args, uses_cuda_stream

 logger = logging.getLogger(__name__)

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Operators to skip (e.g., indexing ops that need valid indices)
SKIP_OPERATORS = [
    "embedding",
    "scatter",
    "gather",
    "index",
    "nll_loss",
    "im2col_backward",
    "col2im_backward",
    "native_layer_norm_backward",
    "upsample_nearest2d_backward.vec",
    "upsample_bilinear2d_backward.vec",
    "_cudnn_rnn_backward.default",  # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
    "_fft_c2c.default",  # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
]


def apply_skip_ops_filter(ops):
    for op in ops:
        if any(skip_op in op["op_name"] for skip_op in SKIP_OPERATORS):
            op["included_in_benchmark"] = False
            op["why_excluded"].append("We cannot run this op on backendbench yet")
            op["runnable"] = False

        if op["is_synthetic"]:
            op["included_in_benchmark"] = False
            op["why_excluded"].append(
                "Synthetic ops are not supported in the official benchmark yet"
            )
    return ops
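
A minimal sketch of how this filter composes with the loader above; initializing the bookkeeping keys here is an assumption, since the loader itself only emits uuid/op_name/args/arg_size/count/is_synthetic records:

from BackendBench.data_loaders import load_ops_from_source

ops = load_ops_from_source("traces/example_trace.txt")  # hypothetical path
for op in ops:
    # The filter mutates these keys, so they must exist beforehand.
    op.update(included_in_benchmark=True, runnable=True, why_excluded=[])

ops = apply_skip_ops_filter(ops)
excluded = [op["op_name"] for op in ops if not op["included_in_benchmark"]]
print(f"excluded {len(excluded)}/{len(ops)} ops")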

BackendBench/scripts/get_big_inputs.py

Lines changed: 1 addition & 11 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-import gc
 import logging
 import os
 import tempfile
@@ -26,6 +25,7 @@
 )
 from main import setup_logging
 from tqdm import tqdm
+from BackendBench.utils import cleanup_memory_and_gpu

 # Magic numbers and constants
 MAX_ITERATIONS = 100  # Maximum binary search iterations to prevent infinite loops
@@ -50,16 +50,6 @@
 log = logging.getLogger(__name__)


-def cleanup_memory_and_gpu(*variables):
-    """Helper function to delete variables and clean up GPU memory"""
-    for var in variables:
-        if var is not None:
-            del var
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
-    gc.collect()
-
-
 def scale_shape(shape: List[int], scale_factor: float) -> List[int]:
     """Scale tensor shape by a factor"""
     return [max(MIN_TENSOR_DIM, int(dim * scale_factor)) for dim in shape]
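
The deleted local helper now lives in BackendBench.utils (the new data_loaders.py imports it from there). A minimal sketch mirroring the call pattern used in _parse_trace_stream, with a hypothetical scratch tensor:

import torch
from BackendBench.utils import cleanup_memory_and_gpu

scratch = torch.randn(4096, 4096, device="cuda")  # hypothetical scratch tensor
del scratch  # drop our own reference first
cleanup_memory_and_gpu()  # then synchronize, empty the CUDA cache, and gc.collect()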

BackendBench/scripts/main.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 import BackendBench.eval as eval
 import click
 import torch
+
 from BackendBench.facto_suite import FactoTestSuite
 from BackendBench.llm_client import ClaudeKernelGenerator, LLMKernelGenerator
 from BackendBench.opinfo_suite import OpInfoTestSuite
