Skip to content

Commit d49a167

Browse files
authored
Grab txt file from huggingface as the default (#38)
1 parent 1b44d71 commit d49a167

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

BackendBench/torchbench_suite.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@
44

55
import math
66
import re
7+
import tempfile
78
from collections import defaultdict
89
from pathlib import Path
910

11+
import requests
1012
import torch
1113
from torch.testing import make_tensor
1214

15+
# The schema for this dataset is the one defined in tritonbench traces,
16+
# e.g. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt
17+
DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt"
18+
1319

1420
dtype_abbrs = {
1521
torch.bfloat16: "bf16",
@@ -120,11 +126,29 @@ def _parse_inputs(filename, filter, op_inputs):
120126

121127

122128
class TorchBenchTestSuite:
123-
def __init__(self, name, filename, filter=None, topn=None):
129+
def __init__(self, name, filename=None, filter=None, topn=None):
124130
self.name = name
125131
self.topn = topn
126132
self.optests = defaultdict(list)
127-
if Path(filename).is_dir():
133+
134+
# Use default URL if no filename provided
135+
if filename is None:
136+
filename = DEFAULT_HUGGINGFACE_URL
137+
138+
# Check if filename is a URL
139+
if isinstance(filename, str) and (
140+
filename.startswith("http://") or filename.startswith("https://")
141+
):
142+
with (
143+
tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file,
144+
requests.get(filename) as response,
145+
):
146+
response.raise_for_status()
147+
tmp_file.write(response.text)
148+
tmp_file.flush()
149+
_parse_inputs(tmp_file.name, filter, self.optests)
150+
Path(tmp_file.name).unlink(missing_ok=True)
151+
elif Path(filename).is_dir():
128152
for file_path in Path(filename).glob("**/*.txt"):
129153
_parse_inputs(str(file_path), filter, self.optests)
130154
else:
@@ -148,6 +172,8 @@ def __iter__(self):
148172
"native_layer_norm_backward",
149173
"upsample_nearest2d_backward.vec",
150174
"upsample_bilinear2d_backward.vec",
175+
"_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
176+
"_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
151177
]
152178
):
153179
# TODO: indexing ops need valid indices

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[tool.ruff]
22
line-length = 100
33

4-
[tool.ruff.format]
4+
[tool.ruff.format]

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ click
33
numpy
44
expecttest
55
anthropic>=0.34.0
6-
pytest
6+
pytest
7+
requests

scripts/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from BackendBench.llm_client import ClaudeKernelGenerator
1111
from BackendBench.opinfo_suite import OpInfoTestSuite
1212
from BackendBench.suite import SmokeTestSuite
13-
from BackendBench.torchbench_suite import TorchBenchTestSuite
13+
from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL, TorchBenchTestSuite
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -80,7 +80,7 @@ def setup_logging(log_level):
8080
)
8181
@click.option(
8282
"--torchbench-data-path",
83-
default="third_party/tritonbench/tritonbench/data/input_configs",
83+
default=DEFAULT_HUGGINGFACE_URL,
8484
type=str,
8585
help="Path to TorchBench operator data",
8686
)

0 commit comments

Comments
 (0)