@@ -137,19 +137,26 @@ def get_op(dotted_name):
137
137
mod = getattr (mod , name )
138
138
return mod
139
139
140
- # Maps function -> OpInfo
140
+ # Maps function -> [ OpInfo]
141
141
def get_ops_covered_by_opinfos ():
142
142
ops = {}
143
+
144
+ def safe_append (dct , key , val ):
145
+ if key in dct :
146
+ dct [key ].append (val )
147
+ else :
148
+ dct [key ] = [val ]
149
+
143
150
for opinfo in op_db :
144
151
func_op = get_op (opinfo .name )
145
152
if func_op :
146
- ops [ func_op ] = opinfo
153
+ safe_append ( ops , func_op , opinfo )
147
154
if opinfo .method_variant :
148
- ops [ opinfo .method_variant ] = opinfo
155
+ safe_append ( ops , opinfo .method_variant , opinfo )
149
156
if opinfo .inplace_variant :
150
- ops [ opinfo .inplace_variant ] = opinfo
157
+ safe_append ( ops , opinfo .inplace_variant , opinfo )
151
158
for alias in opinfo .aliases :
152
- ops [ alias .op ] = opinfo
159
+ safe_append ( ops , alias .op , opinfo )
153
160
return ops
154
161
155
162
def get_top_ops (torch_threshold , nn_fn_threshold ):
@@ -255,35 +262,30 @@ def get_statuses(for_subset=None, invert=False):
255
262
op_to_opinfo = get_ops_covered_by_opinfos ()
256
263
result = {}
257
264
x = get_covered_ops (overridable_outplace_we_care_about )
258
- for name , op in get_covered_ops (overridable_outplace_we_care_about ).items ():
259
- opinfo = op_to_opinfo [op ]
260
- if invert == False :
261
- success = copy .deepcopy (tests )
262
- for decorator in opinfo .decorators :
263
- if not hasattr (decorator , 'test_name' ):
264
- continue
265
- if decorator .test_name in tests and decorator .test_name in success :
266
- success .remove (decorator .test_name )
267
- # NB: disregard aliases, they're too much trouble
268
- for func in [opinfo .op ]:
269
- if opinfo .name not in result .keys ():
270
- result [name ] = success
271
- else :
272
- result [name ] = result [name ].intersection (success )
273
- if invert == True :
274
- failures = set ({})
265
+
266
+ def get_covered_tests (op ):
267
+ opinfos = op_to_opinfo [op ]
268
+ result = copy .deepcopy (tests )
269
+ for opinfo in opinfos :
275
270
for decorator in opinfo .decorators :
276
271
if not hasattr (decorator , 'test_name' ):
277
272
continue
278
- if decorator .test_name in tests :
279
- failures .add (decorator .test_name )
280
-
281
- # NB: disregard aliases, they're too much trouble
282
- for func in [opinfo .op ]:
283
- if opinfo .name not in result .keys ():
284
- result [name ] = failures
285
- else :
286
- result [name ] = result [name ].union (failures )
273
+ if decorator .test_name in tests and decorator .test_name in result :
274
+ result .remove (decorator .test_name )
275
+ return result
276
+
277
+ def get_all_aliases (op ):
278
+ opinfos = op_to_opinfo [op ]
279
+ result = []
280
+ for opinfo in opinfos :
281
+ result .append (opinfo .name )
282
+ result .extend (opinfo .aliases )
283
+ return set (result )
284
+
285
+ for name , op in get_covered_ops (overridable_outplace_we_care_about ).items ():
286
+ successful_tests = get_covered_tests (op )
287
+ failed_tests = tests - successful_tests
288
+ result [name ] = failed_tests if invert else successful_tests
287
289
return result
288
290
289
291
def transpose_statuses (for_subset = None , invert = False ):
0 commit comments