Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit cd881a7

Browse files
committed
fix discover_coverage bug
1 parent 239111d commit cd881a7

File tree

1 file changed

+33
-31
lines changed

1 file changed

+33
-31
lines changed

test/discover_coverage.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,26 @@ def get_op(dotted_name):
137137
mod = getattr(mod, name)
138138
return mod
139139

140-
# Maps function -> OpInfo
140+
# Maps function -> [OpInfo]
141141
def get_ops_covered_by_opinfos():
142142
ops = {}
143+
144+
def safe_append(dct, key, val):
145+
if key in dct:
146+
dct[key].append(val)
147+
else:
148+
dct[key] = [val]
149+
143150
for opinfo in op_db:
144151
func_op = get_op(opinfo.name)
145152
if func_op:
146-
ops[func_op] = opinfo
153+
safe_append(ops, func_op, opinfo)
147154
if opinfo.method_variant:
148-
ops[opinfo.method_variant] = opinfo
155+
safe_append(ops, opinfo.method_variant, opinfo)
149156
if opinfo.inplace_variant:
150-
ops[opinfo.inplace_variant] = opinfo
157+
safe_append(ops, opinfo.inplace_variant, opinfo)
151158
for alias in opinfo.aliases:
152-
ops[alias.op] = opinfo
159+
safe_append(ops, alias.op, opinfo)
153160
return ops
154161

155162
def get_top_ops(torch_threshold, nn_fn_threshold):
@@ -255,35 +262,30 @@ def get_statuses(for_subset=None, invert=False):
255262
op_to_opinfo = get_ops_covered_by_opinfos()
256263
result = {}
257264
x = get_covered_ops(overridable_outplace_we_care_about)
258-
for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
259-
opinfo = op_to_opinfo[op]
260-
if invert == False:
261-
success = copy.deepcopy(tests)
262-
for decorator in opinfo.decorators:
263-
if not hasattr(decorator, 'test_name'):
264-
continue
265-
if decorator.test_name in tests and decorator.test_name in success:
266-
success.remove(decorator.test_name)
267-
# NB: disregard aliases, they're too much trouble
268-
for func in [opinfo.op]:
269-
if opinfo.name not in result.keys():
270-
result[name] = success
271-
else:
272-
result[name] = result[name].intersection(success)
273-
if invert == True:
274-
failures = set({})
265+
266+
def get_covered_tests(op):
267+
opinfos = op_to_opinfo[op]
268+
result = copy.deepcopy(tests)
269+
for opinfo in opinfos:
275270
for decorator in opinfo.decorators:
276271
if not hasattr(decorator, 'test_name'):
277272
continue
278-
if decorator.test_name in tests:
279-
failures.add(decorator.test_name)
280-
281-
# NB: disregard aliases, they're too much trouble
282-
for func in [opinfo.op]:
283-
if opinfo.name not in result.keys():
284-
result[name] = failures
285-
else:
286-
result[name] = result[name].union(failures)
273+
if decorator.test_name in tests and decorator.test_name in result:
274+
result.remove(decorator.test_name)
275+
return result
276+
277+
def get_all_aliases(op):
278+
opinfos = op_to_opinfo[op]
279+
result = []
280+
for opinfo in opinfos:
281+
result.append(opinfo.name)
282+
result.extend(opinfo.aliases)
283+
return set(result)
284+
285+
for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
286+
successful_tests = get_covered_tests(op)
287+
failed_tests = tests - successful_tests
288+
result[name] = failed_tests if invert else successful_tests
287289
return result
288290

289291
def transpose_statuses(for_subset=None, invert=False):

0 commit comments

Comments
 (0)