Skip to content

Productionize the dataset we are using for BackendBench #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
8803f09
Add tests for serialization and deserialization
PaliC Jul 29, 2025
7495107
fix
PaliC Jul 30, 2025
e23bd3a
fix
PaliC Jul 30, 2025
0d54c1c
[ez] get workflows to run on prs (#39)
PaliC Jul 23, 2025
0eb0753
Grab txt file from huggingface as the default (#38)
PaliC Jul 23, 2025
037f7c5
Installable backends (#27)
msaroufim Jul 23, 2025
455b443
Fix flag gems tests and imports (#35)
bertmaher Jul 28, 2025
dd1aa1c
Fixes to kernel agent backend tests (#46)
bertmaher Jul 29, 2025
5a5702a
Filter out solutions that have cuda streams (#56)
PaliC Jul 30, 2025
e6bb19a
Add tests for serialization and deserialization
PaliC Jul 29, 2025
4b1722b
fix
PaliC Jul 30, 2025
3a670c6
fix
PaliC Jul 30, 2025
e4ccfb8
rebase
PaliC Jul 31, 2025
7618519
rebase fix
PaliC Jul 31, 2025
32d52d1
rebase fix
PaliC Jul 31, 2025
d8c186c
Merge branch 'main' into serial
PaliC Jul 31, 2025
1c18247
Adding parquet file
PaliC Jul 31, 2025
1ecb1f7
filtering logic
PaliC Aug 1, 2025
55bcfd6
Merge branch 'main' into parquet
PaliC Aug 1, 2025
a1bdf7a
cleanup
PaliC Aug 1, 2025
32d3c7b
Merge branch 'main' into parquet
PaliC Aug 1, 2025
8940b44
Merge branch 'parquet' of github.com:PaliC/BackendBench into parquet
PaliC Aug 1, 2025
7408e7a
parquet
PaliC Aug 1, 2025
f535d8a
update deps
PaliC Aug 1, 2025
a68fbda
undo lint
PaliC Aug 1, 2025
a58f0d8
update hf upload
PaliC Aug 1, 2025
e8c5d1a
Mark's comments
PaliC Aug 13, 2025
a4c8171
Merge branch 'main' into parquet
PaliC Aug 13, 2025
d25c2d3
lint
PaliC Aug 13, 2025
9ae0cac
stream from urls
PaliC Aug 13, 2025
f23690a
simplify
PaliC Aug 14, 2025
dbe3a8d
lint
PaliC Aug 14, 2025
0705fa6
marks comments
PaliC Aug 15, 2025
9094e1a
Mark's comments
PaliC Aug 15, 2025
5cc096c
Mark's comments
PaliC Aug 15, 2025
37d5b27
remove big inputs from dataset
PaliC Aug 18, 2025
30661de
final fix
PaliC Aug 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ backendbench.egg-info/
CLAUDE.md
venv/
ops/
datasets/
uv.lock
229 changes: 229 additions & 0 deletions BackendBench/data_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""
Shared data loading utilities for reading trace and parquet files.
"""

import hashlib
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Union

import pyarrow.parquet as pq

import requests
import torch
from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args
from tqdm import tqdm


def _args_size(args):
"""Calculate the size of arguments in bytes."""

size = 0
for arg in args:
if isinstance(arg, torch.Tensor):
size += arg.numel() * arg.element_size()
elif isinstance(arg, (tuple, list)):
size += _args_size(arg)
return size


def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List[Dict]:
    """
    Parse a single trace file and return a list of operation dictionaries.

    Args:
        filename: Path to trace file
        filter: Optional list of operation name filters; an op is kept when
            any filter string occurs as a substring of its name (all ops are
            kept when filter is None)

    Returns:
        One dict per matching "cnt:" line with keys: uuid, op_name, args,
        arg_size (MB), count, is_synthetic.
    """
    op_inputs = []
    op = None

    with open(filename, "r") as f:
        # Materialize the file so tqdm knows the total line count.
        lines = list(f)
        iterator = tqdm(lines, desc=f"Parsing {Path(filename).name}")
        for line in iterator:
            if m := re.match("Operator: (.*)", line):
                op = m.group(1)
                # this is due to a version skew error of the pytorch version we're
                # using for developing BackendBench and what was used in tritonbench where
                # SymInt didn't exist.
                # @todo: see if we can remove this before releasing
                if op == "aten.sum.SymInt":
                    op = "aten.sum.dim_IntList"
            if m := re.match("cnt: \\d+, (.*)", line):
                assert op is not None
                args_str = m.group(1)
                cnt = int(m.group(0).split(",")[0].split(":")[1])

                if filter is None or any(f in op for f in filter):
                    args, kwargs = deserialize_args(args_str)
                    size = _args_size(args) + _args_size(list(kwargs.values()))
                    # Free the deserialized tensors immediately (consistent
                    # with _parse_trace_stream); otherwise every traced input
                    # stays alive in (GPU) memory until the whole file is
                    # parsed.
                    del args, kwargs
                    cleanup_memory_and_gpu()
                    size = size / (1024 * 1024)  # Convert to MB
                    is_synthetic = cnt == 0

                    op_inputs.append(
                        {
                            "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(),
                            "op_name": op,
                            "args": args_str,
                            "arg_size": size,
                            "count": cnt,
                            "is_synthetic": is_synthetic,
                        }
                    )
    return op_inputs


def _parse_trace_stream(
    stream, filter: Optional[List[str]] = None, desc: str = "Parsing stream"
) -> List[Dict]:
    """
    Parse trace data from a text stream (e.g., from requests.Response.iter_lines()).

    Args:
        stream: Iterable of lines (strings or bytes)
        filter: Optional list of operation name filters
        desc: Description for progress bar

    Returns:
        One dict per matching "cnt:" line with keys: uuid, op_name, args,
        arg_size (MB), count, is_synthetic.
    """
    results: List[Dict] = []
    current_op = None

    for raw_line in tqdm(stream, desc=desc):
        # Response streams yield bytes; normalize to str before matching.
        line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line

        if op_match := re.match("Operator: (.*)", line):
            current_op = op_match.group(1)
            # Version-skew workaround: map the SymInt overload name emitted
            # by older traces onto its current equivalent.
            if current_op == "aten.sum.SymInt":
                current_op = "aten.sum.dim_IntList"

        cnt_match = re.match("cnt: \\d+, (.*)", line)
        if cnt_match is None:
            continue
        assert current_op is not None

        serialized = cnt_match.group(1)
        call_count = int(cnt_match.group(0).split(",")[0].split(":")[1])

        if filter is not None and not any(f in current_op for f in filter):
            continue

        parsed_args, parsed_kwargs = deserialize_args(serialized)
        size_bytes = _args_size(parsed_args) + _args_size(list(parsed_kwargs.values()))
        # Release the deserialized tensors right away so they don't pile up
        # while the rest of the stream is parsed.
        del parsed_args, parsed_kwargs
        cleanup_memory_and_gpu()

        results.append(
            {
                "uuid": hashlib.sha256(serialized.encode() + current_op.encode()).hexdigest(),
                "op_name": current_op,
                "args": serialized,
                "arg_size": size_bytes / (1024 * 1024),  # Convert to MB
                "count": call_count,
                "is_synthetic": call_count == 0,
            }
        )
    return results


def load_ops_from_source(
    source: Union[str, Path],
    format: str = "auto",
    filter: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Load operation data from various sources and formats.

    Args:
        source: File path (str or pathlib.Path) or URL
        format: "trace", "parquet", or "auto" (detect from file extension)
        filter: Optional list of operation name filters

    Returns:
        List of dictionaries with detailed operation info

    Raises:
        ValueError: if the format cannot be auto-detected from the source,
            or an unknown format is given explicitly.

    Auto-detection behavior:
        - https://domain.com/data.parquet → parquet format
        - https://domain.com/data.txt → trace format
        - https://domain.com/data → trace format (fallback)
        - local_file.parquet → parquet format
        - local_file.txt → trace format
    """

    # Auto-detect format if not specified
    if format == "auto":
        # Normalize Path objects so the extension checks below work for both
        # str and pathlib.Path sources (previously Path inputs always raised).
        source_str = str(source) if isinstance(source, Path) else source
        if isinstance(source_str, str):
            # Check file extension first (works for both local files and URLs)
            if source_str.endswith(".parquet"):
                format = "parquet"
            elif source_str.endswith(".txt"):
                format = "trace"
            elif source_str.startswith(("http://", "https://")):
                # Remote URL without recognizable extension - default to trace
                format = "trace"
            else:
                raise ValueError(f"Unsupported source: {source}")
        else:
            raise ValueError(f"Unsupported source: {source}")

    if format == "parquet":
        return _load_from_parquet(source, filter)
    elif format == "trace":
        # Always load full data - consumers can extract what they need
        return _load_from_trace(source, filter)
    else:
        raise ValueError(f"Unsupported format: {format}")


def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]]):
    """Load operations from a parquet file as a list of record dicts.

    When *filter* is non-empty, keep only rows whose op_name contains at
    least one of the filter strings as a substring.
    """
    records = pq.read_table(source).to_pandas().to_dict("records")

    if filter:
        records = [row for row in records if any(f in row["op_name"] for f in filter)]

    return records


def op_list_to_benchmark_dict(ops_list: List[Dict]) -> Dict[str, List[str]]:
    """
    Convert a list of operation dictionaries to a dictionary format which can be used for benchmarking.

    Args:
        ops_list: List of dicts with 'op_name', 'args', and
            'included_in_benchmark' keys

    Returns:
        Dictionary mapping op_name to list of args strings, covering only the
        ops marked as included in the benchmark, in input order
    """
    benchmark_dict: Dict[str, List[str]] = {}
    for entry in ops_list:
        if entry["included_in_benchmark"]:
            benchmark_dict.setdefault(entry["op_name"], []).append(entry["args"])
    return benchmark_dict


def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]]) -> List[Dict]:
    """Load operations from a local trace file or a remote trace URL."""
    is_url = isinstance(source, str) and source.startswith(("http://", "https://"))

    # Local path: parse the file directly.
    if not is_url:
        return _parse_trace_file(source, filter)

    # Remote URL: stream line-by-line without saving to disk.
    logging.info(f"Downloading trace from {source}")
    with requests.get(source, stream=True) as response:
        response.raise_for_status()
        return _parse_trace_stream(response.iter_lines(), filter, "Parsing")
4 changes: 1 addition & 3 deletions BackendBench/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import triton.testing


from BackendBench.utils import uses_cuda_stream
from BackendBench.utils import serialize_args
from BackendBench.utils import serialize_args, uses_cuda_stream

logger = logging.getLogger(__name__)

Expand Down
30 changes: 30 additions & 0 deletions BackendBench/scripts/dataset_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Operators to skip for indexing ops that need valid indices
SKIP_OPERATORS = [
    "embedding",
    "scatter",
    "gather",
    "index",
    "nll_loss",
    "im2col_backward",
    "col2im_backward",
    "native_layer_norm_backward",
    "upsample_nearest2d_backward.vec",
    "upsample_bilinear2d_backward.vec",
    "_cudnn_rnn_backward.default",  # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
    "_fft_c2c.default",  # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
]


def apply_skip_ops_filter(ops):
    """Exclude unrunnable and synthetic ops from the benchmark.

    Mutates each op dict in place (included_in_benchmark, runnable,
    why_excluded) and returns the same list.
    """
    for entry in ops:
        name = entry["op_name"]

        # Ops we cannot execute yet: drop from the benchmark entirely.
        if any(skip in name for skip in SKIP_OPERATORS):
            entry["included_in_benchmark"] = False
            entry["why_excluded"].append("We cannot run this op on backendbench yet")
            entry["runnable"] = False

        # Synthetic (count == 0) inputs are excluded but may still be runnable.
        if entry["is_synthetic"]:
            entry["included_in_benchmark"] = False
            entry["why_excluded"].append(
                "Synthetic ops are not supported in the official benchmark yet"
            )
    return ops
12 changes: 1 addition & 11 deletions BackendBench/scripts/get_big_inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import gc
import logging
import os
import tempfile
Expand All @@ -20,6 +19,7 @@
)
from main import setup_logging
from tqdm import tqdm
from BackendBench.utils import cleanup_memory_and_gpu

# Magic numbers and constants
MAX_ITERATIONS = 100 # Maximum binary search iterations to prevent infinite loops
Expand All @@ -44,16 +44,6 @@
log = logging.getLogger(__name__)


def cleanup_memory_and_gpu(*variables):
"""Helper function to delete variables and clean up GPU memory"""
for var in variables:
if var is not None:
del var
torch.cuda.synchronize()
torch.cuda.empty_cache()
gc.collect()


def scale_shape(shape: List[int], scale_factor: float) -> List[int]:
"""Scale tensor shape by a factor"""
return [max(MIN_TENSOR_DIM, int(dim * scale_factor)) for dim in shape]
Expand Down
1 change: 1 addition & 0 deletions BackendBench/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import BackendBench.eval as eval
import click
import torch

from BackendBench.facto_suite import FactoTestSuite
from BackendBench.llm_client import ClaudeKernelGenerator, LLMKernelGenerator
from BackendBench.opinfo_suite import OpInfoTestSuite
Expand Down
Loading
Loading