2
2
Load aten inputs from serialized txt files.
3
3
"""
4
4
5
- import re
6
5
import math
6
+ import re
7
7
from collections import defaultdict
8
8
from pathlib import Path
9
9
33
33
34
34
35
35
def _deserialize_tensor(size, dtype, stride=None, device="cuda"):
    """Build a tensor filled by ``make_tensor`` with the requested layout.

    Args:
        size: Sizes of the tensor's dimensions.
        dtype: Element type. Dtypes listed in ``_FLOATING_TYPES`` are generated
            with ``low=0`` and ``high=1``; other dtypes use ``make_tensor``'s
            defaults.
        stride: Optional per-dimension strides. When given, the result is a
            strided view over a freshly generated 1-D buffer.
        device: Device on which the tensor is created.

    Returns:
        A tensor of the given ``size`` (and ``stride`` when one is supplied).
    """
    value_range = {}
    if dtype in _FLOATING_TYPES:
        value_range.update({"low": 0, "high": 1})

    if stride is None:
        return make_tensor(size, dtype=dtype, device=device, **value_range)

    if any(dim == 0 for dim in size):
        # A tensor with a zero-length dimension addresses no elements. The
        # offset formula below could turn negative here (e.g. size (0, 2) with
        # stride (10, 1)), so use an empty backing buffer instead.
        extent = 0
    else:
        # Offset of the last addressable element, plus one.
        extent = 1 + sum((dim - 1) * step for dim, step in zip(size, stride))
    data = make_tensor(extent, dtype=dtype, device=device, **value_range)
    return data.as_strided(size, stride)
43
44
44
45
45
46
def _deserialize_args (inps ):
@@ -63,20 +64,40 @@ def __init__(self, *args, **kwargs):
63
64
self .kwargs = kwargs
64
65
65
66
67
+ def _args_size (args ):
68
+ size = 0
69
+ for arg in args :
70
+ if isinstance (arg , torch .Tensor ):
71
+ size += arg .numel () * arg .element_size ()
72
+ elif isinstance (arg , (tuple , list )):
73
+ size += _args_size (arg )
74
+ return size
75
+
76
+
66
77
class TorchBenchOpTest:
    """Test cases for a single operator, built from its serialized inputs."""

    def __init__(self, op, inputs, topn):
        # NOTE(review): eval executes a string built from `op`; confirm `op`
        # never originates from untrusted input. The trailing space in the
        # expression is stripped by eval.
        self.op = eval(f"torch.ops.{op} ")
        self.inputs = inputs
        self.topn = topn

    def tests(self):
        """Return the serialized inputs ordered largest first.

        Each input is deserialized to measure the byte size of its positional
        and keyword arguments. When ``topn`` is not None only the first
        ``topn`` entries of that ordering are returned.
        """
        sized = []
        for serialized in self.inputs:
            args, kwargs = _deserialize_args(serialized)
            nbytes = _args_size(args) + _args_size(list(kwargs.values()))
            sized.append((nbytes, serialized))
        sized.sort(reverse=True)
        ordered = [serialized for _, serialized in sized]
        if self.topn is None:
            return ordered
        return ordered[: self.topn]

    def _iter_cases(self):
        # Deserialize each selected input lazily and wrap it in a TorchBenchTest.
        for serialized in self.tests():
            args, kwargs = _deserialize_args(serialized)
            yield TorchBenchTest(*args, **kwargs)

    @property
    def correctness_tests(self):
        yield from self._iter_cases()

    @property
    def performance_tests(self):
        yield from self._iter_cases()
82
103
@@ -99,8 +120,9 @@ def _parse_inputs(filename, filter, op_inputs):
99
120
100
121
101
122
class TorchBenchTestSuite :
102
- def __init__ (self , name , filename , filter = None ):
123
+ def __init__ (self , name , filename , filter = None , topn = None ):
103
124
self .name = name
125
+ self .topn = topn
104
126
self .optests = defaultdict (list )
105
127
if Path (filename ).is_dir ():
106
128
for file_path in Path (filename ).glob ("**/*.txt" ):
@@ -113,7 +135,21 @@ def __init__(self, name, filename, filter=None):
113
135
114
136
def __iter__(self):
    """Yield a TorchBenchOpTest for every recorded op that is not skipped.

    An op is skipped when its name contains any of the fragments listed
    below.
    """
    skipped_fragments = (
        "embedding",
        "scatter",
        "gather",
        "index",
        "nll_loss",
        "im2col_backward",
        "col2im_backward",
        "native_layer_norm_backward",
        "upsample_nearest2d_backward.vec",
        "upsample_bilinear2d_backward.vec",
    )
    for op, inputs in self.optests.items():
        # TODO: indexing ops need valid indices
        if not any(fragment in op for fragment in skipped_fragments):
            yield TorchBenchOpTest(op, inputs, self.topn)
0 commit comments