Productionize the dataset we are using for BackendBench #93
Status: merged. Changes from all 38 commits:
- 8803f09 Add tests for serialization and deserialization (PaliC)
- 7495107 fix (PaliC)
- e23bd3a fix (PaliC)
- 0d54c1c [ez] get workflows to run on prs (#39) (PaliC)
- 0eb0753 Grab txt file from huggingface as the default (#38) (PaliC)
- 037f7c5 Installable backends (#27) (msaroufim)
- 455b443 Fix flag gems tests and imports (#35) (bertmaher)
- dd1aa1c Fixes to kernel agent backend tests (#46) (bertmaher)
- 5a5702a Filter out solutions that have cuda streams (#56) (PaliC)
- e6bb19a Add tests for serialization and deserialization (PaliC)
- 4b1722b fix (PaliC)
- 3a670c6 fix (PaliC)
- e4ccfb8 rebase (PaliC)
- 7618519 rebase fix (PaliC)
- 32d52d1 rebase fix (PaliC)
- d8c186c Merge branch 'main' into serial (PaliC)
- 1c18247 Adding parquet file (PaliC)
- 1ecb1f7 filtering logic (PaliC)
- 55bcfd6 Merge branch 'main' into parquet (PaliC)
- a1bdf7a cleanup (PaliC)
- 32d3c7b Merge branch 'main' into parquet (PaliC)
- 8940b44 Merge branch 'parquet' of github.com:PaliC/BackendBench into parquet (PaliC)
- 7408e7a parquet (PaliC)
- f535d8a udpate deps (PaliC)
- a68fbda undo lint (PaliC)
- a58f0d8 update hf upload (PaliC)
- e8c5d1a Mark's comments (PaliC)
- a4c8171 Merge branch 'main' into parquet (PaliC)
- d25c2d3 lint (PaliC)
- 9ae0cac stream from urls (PaliC)
- f23690a simplify (PaliC)
- dbe3a8d lint (PaliC)
- 0705fa6 marks comments (PaliC)
- 9094e1a Mark's comments (PaliC)
- 5cc096c Mark's comments (PaliC)
- 37d5b27 remove big inputs from dataset (PaliC)
- 30661de final fix (PaliC)
- 2a50a4a licenses (PaliC)
.gitignore:

```diff
@@ -7,4 +7,5 @@ backendbench.egg-info/
 CLAUDE.md
 venv/
 ops/
+datasets/
 uv.lock
```
New file (@@ -0,0 +1,235 @@): shared data loading utilities.

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Shared data loading utilities for reading trace and parquet files.
"""

import hashlib
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Union

import pyarrow.parquet as pq
import requests
import torch
from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args
from tqdm import tqdm


def _args_size(args):
    """Calculate the size of arguments in bytes."""
    size = 0
    for arg in args:
        if isinstance(arg, torch.Tensor):
            size += arg.numel() * arg.element_size()
        elif isinstance(arg, (tuple, list)):
            size += _args_size(arg)
    return size


def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List[Dict]:
    """
    Parse a single trace file and return a list of operation dictionaries.

    Args:
        filename: Path to trace file
        filter: Optional list of operation name filters
    """
    op_inputs = []
    op = None

    with open(filename, "r") as f:
        lines = list(f)
        iterator = tqdm(lines, desc=f"Parsing {Path(filename).name}")
        for line in iterator:
            if m := re.match("Operator: (.*)", line):
                op = m.group(1)
                # This remapping works around version skew between the PyTorch
                # version we're using to develop BackendBench and the one used
                # in tritonbench, where SymInt didn't exist.
                # @todo: see if we can remove this before releasing
                if op == "aten.sum.SymInt":
                    op = "aten.sum.dim_IntList"
            if m := re.match("cnt: \\d+, (.*)", line):
                assert op is not None
                args_str = m.group(1)
                cnt = int(m.group(0).split(",")[0].split(":")[1])

                if filter is None or any(f in op for f in filter):
                    args, kwargs = deserialize_args(args_str)
                    size = _args_size(args) + _args_size(list(kwargs.values()))
                    size = size / (1024 * 1024)  # Convert to MB
                    is_synthetic = cnt == 0

                    op_inputs.append(
                        {
                            "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(),
                            "op_name": op,
                            "args": args_str,
                            "arg_size": size,
                            "count": cnt,
                            "is_synthetic": is_synthetic,
                        }
                    )
    return op_inputs


def _parse_trace_stream(
    stream, filter: Optional[List[str]] = None, desc: str = "Parsing stream"
) -> List[Dict]:
    """
    Parse trace data from a text stream (e.g., from requests.Response.iter_lines()).

    Args:
        stream: Iterable of lines (strings or bytes)
        filter: Optional list of operation name filters
        desc: Description for progress bar
    """
    op_inputs = []
    op = None

    iterator = tqdm(stream, desc=desc)

    for line in iterator:
        # Handle bytes from response stream
        if isinstance(line, bytes):
            line = line.decode("utf-8")

        if m := re.match("Operator: (.*)", line):
            op = m.group(1)
            if op == "aten.sum.SymInt":
                op = "aten.sum.dim_IntList"
        if m := re.match("cnt: \\d+, (.*)", line):
            assert op is not None
            args_str = m.group(1)
            cnt = int(m.group(0).split(",")[0].split(":")[1])

            if filter is None or any(f in op for f in filter):
                args, kwargs = deserialize_args(args_str)
                size = _args_size(args) + _args_size(list(kwargs.values()))
                del args, kwargs
                cleanup_memory_and_gpu()
                size = size / (1024 * 1024)  # Convert to MB
                is_synthetic = cnt == 0

                op_inputs.append(
                    {
                        "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(),
                        "op_name": op,
                        "args": args_str,
                        "arg_size": size,
                        "count": cnt,
                        "is_synthetic": is_synthetic,
                    }
                )
    return op_inputs
```
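Both parsers expect a plain-text trace format in which an `Operator:` header introduces each op and is followed by one `cnt:`-prefixed line per distinct argument set (a count of 0 marks the input as synthetic). As an illustrative sketch only, with the tensor shorthand assumed rather than taken from this PR (the args string is whatever `deserialize_args` accepts), a trace might look like:

```
Operator: aten.add.Tensor
cnt: 42, ((T([128, 256], f16), T([128, 256], f16)), {})
cnt: 0, ((T([1, 1], f32), T([1, 1], f32)), {})
Operator: aten.mul.Tensor
cnt: 7, ((T([64], f32), 2.0), {})
```

The file continues with the format dispatcher and the parquet loader: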
```python
def load_ops_from_source(
    source: Union[str, Path],
    format: str = "auto",
    filter: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Load operation data from various sources and formats.

    Args:
        source: File path or URL
        format: "trace", "parquet", or "auto" (detect from file extension)
        filter: Optional list of operation name filters

    Returns:
        List of dictionaries with detailed operation info

    Auto-detection behavior:
        - https://domain.com/data.parquet → parquet format
        - https://domain.com/data.txt → trace format
        - https://domain.com/data → trace format (fallback)
        - local_file.parquet → parquet format
        - local_file.txt → trace format
    """
    # Auto-detect format if not specified
    if format == "auto":
        if isinstance(source, str):
            # Check file extension first (works for both local files and URLs)
            if source.endswith(".parquet"):
                format = "parquet"
            elif source.endswith(".txt"):
                format = "trace"
            elif source.startswith(("http://", "https://")):
                # Remote URL without recognizable extension - default to trace
                format = "trace"
            else:
                raise ValueError(f"Unsupported source: {source}")
        else:
            raise ValueError(f"Unsupported source: {source}")

    if format == "parquet":
        return _load_from_parquet(source, filter)
    elif format == "trace":
        # Always load full data - consumers can extract what they need
        return _load_from_trace(source, filter)
    else:
        raise ValueError(f"Unsupported format: {format}")


def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]]):
    """Load operations from parquet file."""
    table = pq.read_table(source)
    df = table.to_pandas()

    # Apply filter if provided
    if filter:
        mask = df["op_name"].apply(lambda op: any(f in op for f in filter))
        df = df[mask]

    return df.to_dict("records")


def op_list_to_benchmark_dict(ops_list: List[Dict]) -> Dict[str, List[str]]:
    """
    Convert a list of operation dictionaries to a dictionary format which can
    be used for benchmarking.

    Args:
        ops_list: List of dicts with 'op_name' and 'args' keys

    Returns:
        Dictionary mapping op_name to list of args strings
    """
    result = {}
    for op_data in ops_list:
        if not op_data["included_in_benchmark"]:
            continue
        op_name = op_data["op_name"]
        args = op_data["args"]
        if op_name not in result:
            result[op_name] = []
        result[op_name].append(args)
    return result


def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]]) -> List[Dict]:
    """Load operations from trace file(s) and return list of dicts."""
    op_inputs = []

    # Handle URLs - stream directly without saving to disk
    if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")):
        logging.info(f"Downloading trace from {source}")
        with requests.get(source, stream=True) as response:
            response.raise_for_status()
            desc = "Parsing"
            op_inputs = _parse_trace_stream(response.iter_lines(), filter, desc)

    # Handle single files
    else:
        op_inputs = _parse_trace_file(source, filter)

    return op_inputs
```
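To show how these pieces compose, here is a hedged usage sketch. The module path `BackendBench.data_loaders` is an assumption (the diff does not show the new file's name), and the parquet columns simply mirror the record dicts the parsers emit:

```python
# Usage sketch only: the import path is assumed, and the args string is an
# illustrative placeholder for whatever deserialize_args accepts.
import pyarrow as pa
import pyarrow.parquet as pq

from BackendBench.data_loaders import (  # assumed module path
    load_ops_from_source,
    op_list_to_benchmark_dict,
)

# Build a tiny parquet file with the same columns the parsers emit, plus the
# included_in_benchmark flag that op_list_to_benchmark_dict checks.
table = pa.table(
    {
        "uuid": ["deadbeef"],
        "op_name": ["aten.add.Tensor"],
        "args": ["((T([2, 2], f32), T([2, 2], f32)), {})"],
        "arg_size": [3.05e-05],  # MB
        "count": [17],
        "is_synthetic": [False],
        "included_in_benchmark": [True],
    }
)
pq.write_table(table, "mini_ops.parquet")

# The ".parquet" extension routes format="auto" to the parquet loader.
ops = load_ops_from_source("mini_ops.parquet", filter=["aten.add"])
suite = op_list_to_benchmark_dict(ops)
print(suite)  # {'aten.add.Tensor': ['((T([2, 2], f32), T([2, 2], f32)), {})']}
```

Note that `op_list_to_benchmark_dict` keys on `included_in_benchmark`, which the trace parsers themselves do not set; presumably the filtering pass in the next file is expected to run (or the field to be initialized) before conversion.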
New file (@@ -0,0 +1,36 @@): dataset filtering logic.

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Operators to skip, e.g. indexing ops that need valid indices
SKIP_OPERATORS = [
    "embedding",
    "scatter",
    "gather",
    "index",
    "nll_loss",
    "im2col_backward",
    "col2im_backward",
    "native_layer_norm_backward",
    "upsample_nearest2d_backward.vec",
    "upsample_bilinear2d_backward.vec",
    "_cudnn_rnn_backward.default",  # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
    "_fft_c2c.default",  # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
]


def apply_skip_ops_filter(ops):
    for op in ops:
        if any(skip_op in op["op_name"] for skip_op in SKIP_OPERATORS):
            op["included_in_benchmark"] = False
            op["why_excluded"].append("We cannot run this op on backendbench yet")
            op["runnable"] = False

        if op["is_synthetic"]:
            op["included_in_benchmark"] = False
            op["why_excluded"].append(
                "Synthetic ops are not supported in the official benchmark yet"
            )
    return ops
```
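And a matching sketch for the filter. The import path is assumed, and the field initialization is inferred rather than shown in this PR, since `apply_skip_ops_filter` appends to `why_excluded` and therefore expects the key to already hold a list:

```python
# Sketch only: import path and default-field initialization are assumptions.
from dataset_filters import apply_skip_ops_filter  # assumed module name

ops = [
    {"op_name": "aten.gather.default", "is_synthetic": False},
    {"op_name": "aten.add.Tensor", "is_synthetic": True},
    {"op_name": "aten.mul.Tensor", "is_synthetic": False},
]
# Fields the filter mutates must exist before it runs.
for op in ops:
    op.setdefault("included_in_benchmark", True)
    op.setdefault("why_excluded", [])
    op.setdefault("runnable", True)

for op in apply_skip_ops_filter(ops):
    print(op["op_name"], op["included_in_benchmark"], op["why_excluded"])
# aten.gather.default False ['We cannot run this op on backendbench yet']
# aten.add.Tensor False ['Synthetic ops are not supported in the official benchmark yet']
# aten.mul.Tensor True []
```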
Review comments

Reviewer: This doesn't seem to be a script.

PaliC: It's not a script, though the next PR I'm adding after this one will add runtime data, and I'd want to do it here. Also, we'll likely end up skipping more tests for the benchmark, so I think having all the filtering logic in one place would be smart.