Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit e690c82

Browse files
committed
update discover_coverage
1 parent ff00173 commit e690c82

File tree

1 file changed

+67
-20
lines changed

1 file changed

+67
-20
lines changed

test/discover_coverage.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
]
2222

2323
# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different
24-
def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/whiteboard'):
24+
def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/debug-cpu'):
2525
results = {}
2626
all_overridable_apis = set(torch.overrides.get_testing_overrides().keys())
2727
for module, module_name, src in public_docs:
@@ -151,7 +151,7 @@ def get_ops_covered_by_opinfos():
151151
ops[alias.op] = opinfo
152152
return ops
153153

154-
def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
154+
def get_top_ops(torch_threshold, nn_fn_threshold):
155155
denylist = set({
156156
'tensor', 'load', 'zeros', 'no_grad', 'save', 'from_numpy',
157157
'manual_seed', 'ones', 'randn', 'arange', 'rand',
@@ -165,9 +165,15 @@ def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
165165
'equal', 'enable_grad', 'seed', 'is_storage', 'hamming_window',
166166
'is_floating_point', 'nn.functional.torch',
167167
})
168+
168169
torch_ops = [op[0] for op in top_ops.top_torch[:torch_threshold]]
169170
nn_fn_ops = [op[0] for op in top_ops.top_nn_functional[:nn_fn_threshold]]
170171
ops = torch_ops + nn_fn_ops
172+
ops = [op for op in ops if op not in denylist]
173+
return ops
174+
175+
def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
176+
ops = get_top_ops(torch_threshold, nn_fn_threshold)
171177

172178
ops_with_opinfo = []
173179
for op in op_db:
@@ -203,29 +209,51 @@ class Status(Enum):
203209
'test_vmapvjp_has_batch_rule',
204210
}
205211

206-
def get_statuses():
212+
def get_statuses(for_subset=None, invert=False):
207213
overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
214+
if for_subset is not None:
215+
overridable_outplace_we_care_about = {
216+
k: v
217+
for k, v in overridable_outplace_we_care_about.items()
218+
# Removes "torch."
219+
if k[6:] in for_subset
220+
}
208221
op_to_opinfo = get_ops_covered_by_opinfos()
209222
result = {}
210223
x = get_covered_ops(overridable_outplace_we_care_about)
211224
for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
212225
opinfo = op_to_opinfo[op]
213-
success = copy.deepcopy(tests)
214-
for decorator in opinfo.decorators:
215-
if not hasattr(decorator, 'test_name'):
216-
continue
217-
if decorator.test_name in tests and decorator.test_name in success:
218-
success.remove(decorator.test_name)
219-
# NB: disregard aliases, they're too much trouble
220-
for func in [opinfo.op]:
221-
if opinfo.name not in result.keys():
222-
result[name] = success
223-
else:
224-
result[name] = result[name].intersection(success)
226+
if invert == False:
227+
success = copy.deepcopy(tests)
228+
for decorator in opinfo.decorators:
229+
if not hasattr(decorator, 'test_name'):
230+
continue
231+
if decorator.test_name in tests and decorator.test_name in success:
232+
success.remove(decorator.test_name)
233+
# NB: disregard aliases, they're too much trouble
234+
for func in [opinfo.op]:
235+
if opinfo.name not in result.keys():
236+
result[name] = success
237+
else:
238+
result[name] = result[name].intersection(success)
239+
if invert == True:
240+
failures = set({})
241+
for decorator in opinfo.decorators:
242+
if not hasattr(decorator, 'test_name'):
243+
continue
244+
if decorator.test_name in tests:
245+
failures.add(decorator.test_name)
246+
247+
# NB: disregard aliases, they're too much trouble
248+
for func in [opinfo.op]:
249+
if opinfo.name not in result.keys():
250+
result[name] = failures
251+
else:
252+
result[name] = result[name].union(failures)
225253
return result
226254

227-
def transpose_statuses():
228-
statuses = get_statuses()
255+
def transpose_statuses(for_subset=None, invert=False):
256+
statuses = get_statuses(for_subset, invert=invert)
229257
result = {}
230258
for test in tests:
231259
result[test] = set({})
@@ -270,6 +298,25 @@ def transpose_statuses():
270298
# for op in top_ops_not_covered_by_opinfo:
271299
# print(op)
272300

273-
top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 40)
274-
for op in top_ops_not_covered_by_opinfo:
275-
print(op)
301+
# print("top ops not covered by opinfo: ")
302+
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 40)
303+
# for op in top_ops_not_covered_by_opinfo:
304+
# print('- ' + op)
305+
306+
# print("top ops not covered by opinfo: ")
307+
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 40)
308+
# for op in top_ops_not_covered_by_opinfo:
309+
# print('- ' + op)
310+
311+
def print_coverage_info(th=100, nn=25):
312+
print('=' * 80)
313+
print(f"top {th}, {nn} coverage")
314+
statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
315+
top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)
316+
print(f"total ops in set: {th + nn}")
317+
print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
318+
for test in tests:
319+
print(f'{test} failing coverage {len(statuses[test])}')
320+
321+
print_coverage_info(100, 25)
322+
print_coverage_info(200, 50)

0 commit comments

Comments
 (0)