21
21
]
22
22
23
23
# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different
24
- def get_public_overridable_apis (pytorch_root = '/raid/rzou/pt/whiteboard ' ):
24
+ def get_public_overridable_apis (pytorch_root = '/raid/rzou/pt/debug-cpu ' ):
25
25
results = {}
26
26
all_overridable_apis = set (torch .overrides .get_testing_overrides ().keys ())
27
27
for module , module_name , src in public_docs :
@@ -151,7 +151,7 @@ def get_ops_covered_by_opinfos():
151
151
ops [alias .op ] = opinfo
152
152
return ops
153
153
154
- def get_top_ops_not_covered_by_opinfo (torch_threshold = 0 , nn_fn_threshold = 0 ):
154
+ def get_top_ops (torch_threshold , nn_fn_threshold ):
155
155
denylist = set ({
156
156
'tensor' , 'load' , 'zeros' , 'no_grad' , 'save' , 'from_numpy' ,
157
157
'manual_seed' , 'ones' , 'randn' , 'arange' , 'rand' ,
@@ -165,9 +165,15 @@ def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
165
165
'equal' , 'enable_grad' , 'seed' , 'is_storage' , 'hamming_window' ,
166
166
'is_floating_point' , 'nn.functional.torch' ,
167
167
})
168
+
168
169
torch_ops = [op [0 ] for op in top_ops .top_torch [:torch_threshold ]]
169
170
nn_fn_ops = [op [0 ] for op in top_ops .top_nn_functional [:nn_fn_threshold ]]
170
171
ops = torch_ops + nn_fn_ops
172
+ ops = [op for op in ops if op not in denylist ]
173
+ return ops
174
+
175
+ def get_top_ops_not_covered_by_opinfo (torch_threshold = 0 , nn_fn_threshold = 0 ):
176
+ ops = get_top_ops (torch_threshold , nn_fn_threshold )
171
177
172
178
ops_with_opinfo = []
173
179
for op in op_db :
@@ -203,29 +209,51 @@ class Status(Enum):
203
209
'test_vmapvjp_has_batch_rule' ,
204
210
}
205
211
206
- def get_statuses ():
212
+ def get_statuses (for_subset = None , invert = False ):
207
213
overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about ()
214
+ if for_subset is not None :
215
+ overridable_outplace_we_care_about = {
216
+ k : v
217
+ for k , v in overridable_outplace_we_care_about .items ()
218
+ # Removes "torch."
219
+ if k [6 :] in for_subset
220
+ }
208
221
op_to_opinfo = get_ops_covered_by_opinfos ()
209
222
result = {}
210
223
x = get_covered_ops (overridable_outplace_we_care_about )
211
224
for name , op in get_covered_ops (overridable_outplace_we_care_about ).items ():
212
225
opinfo = op_to_opinfo [op ]
213
- success = copy .deepcopy (tests )
214
- for decorator in opinfo .decorators :
215
- if not hasattr (decorator , 'test_name' ):
216
- continue
217
- if decorator .test_name in tests and decorator .test_name in success :
218
- success .remove (decorator .test_name )
219
- # NB: disregard aliases, they're too much trouble
220
- for func in [opinfo .op ]:
221
- if opinfo .name not in result .keys ():
222
- result [name ] = success
223
- else :
224
- result [name ] = result [name ].intersection (success )
226
+ if invert == False :
227
+ success = copy .deepcopy (tests )
228
+ for decorator in opinfo .decorators :
229
+ if not hasattr (decorator , 'test_name' ):
230
+ continue
231
+ if decorator .test_name in tests and decorator .test_name in success :
232
+ success .remove (decorator .test_name )
233
+ # NB: disregard aliases, they're too much trouble
234
+ for func in [opinfo .op ]:
235
+ if opinfo .name not in result .keys ():
236
+ result [name ] = success
237
+ else :
238
+ result [name ] = result [name ].intersection (success )
239
+ if invert == True :
240
+ failures = set ({})
241
+ for decorator in opinfo .decorators :
242
+ if not hasattr (decorator , 'test_name' ):
243
+ continue
244
+ if decorator .test_name in tests :
245
+ failures .add (decorator .test_name )
246
+
247
+ # NB: disregard aliases, they're too much trouble
248
+ for func in [opinfo .op ]:
249
+ if opinfo .name not in result .keys ():
250
+ result [name ] = failures
251
+ else :
252
+ result [name ] = result [name ].union (failures )
225
253
return result
226
254
227
- def transpose_statuses ():
228
- statuses = get_statuses ()
255
+ def transpose_statuses (for_subset = None , invert = False ):
256
+ statuses = get_statuses (for_subset , invert = invert )
229
257
result = {}
230
258
for test in tests :
231
259
result [test ] = set ({})
@@ -270,6 +298,25 @@ def transpose_statuses():
270
298
# for op in top_ops_not_covered_by_opinfo:
271
299
# print(op)
272
300
273
- top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo (200 , 40 )
274
- for op in top_ops_not_covered_by_opinfo :
275
- print (op )
301
+ # print("top ops not covered by opinfo: ")
302
+ # top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 40)
303
+ # for op in top_ops_not_covered_by_opinfo:
304
+ # print('- ' + op)
305
+
306
+ # print("top ops not covered by opinfo: ")
307
+ # top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 40)
308
+ # for op in top_ops_not_covered_by_opinfo:
309
+ # print('- ' + op)
310
+
311
+ def print_coverage_info (th = 100 , nn = 25 ):
312
+ print ('=' * 80 )
313
+ print (f"top { th } , { nn } coverage" )
314
+ statuses = transpose_statuses (get_top_ops (th , nn ), invert = True )
315
+ top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo (th , nn )
316
+ print (f"total ops in set: { th + nn } " )
317
+ print (f"tested by OpInfo: { th + nn - len (top_ops_not_covered_by_opinfo )} " )
318
+ for test in tests :
319
+ print (f'{ test } failing coverage { len (statuses [test ])} ' )
320
+
321
+ print_coverage_info (100 , 25 )
322
+ print_coverage_info (200 , 50 )
0 commit comments