4
4
5
5
import math
6
6
import re
7
+ import tempfile
7
8
from collections import defaultdict
8
9
from pathlib import Path
9
10
11
+ import requests
10
12
import torch
11
13
from torch .testing import make_tensor
12
14
15
+ # The schema for this dataset is the one defined in the tritonbench traces.
16
+ # e.g. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt
17
+ DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt"
18
+
13
19
14
20
dtype_abbrs = {
15
21
torch .bfloat16 : "bf16" ,
@@ -120,11 +126,29 @@ def _parse_inputs(filename, filter, op_inputs):
120
126
121
127
122
128
class TorchBenchTestSuite :
123
- def __init__ (self , name , filename , filter = None , topn = None ):
129
+ def __init__ (self , name , filename = None , filter = None , topn = None ):
124
130
self .name = name
125
131
self .topn = topn
126
132
self .optests = defaultdict (list )
127
- if Path (filename ).is_dir ():
133
+
134
+ # Use default URL if no filename provided
135
+ if filename is None :
136
+ filename = DEFAULT_HUGGINGFACE_URL
137
+
138
+ # Check if filename is a URL
139
+ if isinstance (filename , str ) and (
140
+ filename .startswith ("http://" ) or filename .startswith ("https://" )
141
+ ):
142
+ with (
143
+ tempfile .NamedTemporaryFile (mode = "w+" , suffix = ".txt" , delete = False ) as tmp_file ,
144
+ requests .get (filename ) as response ,
145
+ ):
146
+ response .raise_for_status ()
147
+ tmp_file .write (response .text )
148
+ tmp_file .flush ()
149
+ _parse_inputs (tmp_file .name , filter , self .optests )
150
+ Path (tmp_file .name ).unlink (missing_ok = True )
151
+ elif Path (filename ).is_dir ():
128
152
for file_path in Path (filename ).glob ("**/*.txt" ):
129
153
_parse_inputs (str (file_path ), filter , self .optests )
130
154
else :
@@ -148,6 +172,8 @@ def __iter__(self):
148
172
"native_layer_norm_backward" ,
149
173
"upsample_nearest2d_backward.vec" ,
150
174
"upsample_bilinear2d_backward.vec" ,
175
+ "_cudnn_rnn_backward.default" , # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
176
+ "_fft_c2c.default" , # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
151
177
]
152
178
):
153
179
# TODO: indexing ops need valid indices
0 commit comments