Skip to content

Commit 46bbdce

Browse files
committed
[discover_coverage] add ability to compute percentages
1 parent a8987d8 commit 46bbdce

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

test/discover_coverage.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,27 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
237237
return ops
238238

239239

240+
def get_ops_percentage(torch_threshold, nn_fn_threshold):
241+
data = top_ops.top_torch + top_ops.get_nn_functional_top_list()
242+
243+
def get_num_usages(opname):
244+
# Ignore this, this is heavily inflated
245+
if opname == 't':
246+
return 0
247+
result = [op[1] for op in data if op[0] == opname]
248+
assert len(result) == 1
249+
return result[0]
250+
251+
# get all operators that are not in the denylist
252+
all_ops = get_top_ops(999999, 999999)
253+
total_op_usages = sum([get_num_usages(op) for op in all_ops])
254+
255+
# get subset of all operators
256+
subset_ops = get_top_ops(torch_threshold, nn_fn_threshold)
257+
subset_op_usages = sum([get_num_usages(op) for op in subset_ops])
258+
return subset_op_usages / total_op_usages
259+
260+
240261
def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
241262
ops = get_top_ops(torch_threshold, nn_fn_threshold)
242263

@@ -811,13 +832,15 @@ def summary(self):
811832
has_no_opinfo = opset.query(Operator.has_opinfo, (False,))
812833

813834
print("=" * 30 + " Summary " + "=" * 30)
835+
print(f'% of usages on github: {get_ops_percentage(99999, 99999)}')
814836
print(opset.summary())
815837

816838
# sanity checks
817839
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
818840
# pprint.pprint(result)
819841

820842
print("=" * 30 + " Top 60 Summary " + "=" * 30)
843+
print(f'% of usages on github: {get_ops_percentage(35, 25)}')
821844
opset = OperatorSet.from_top_ops_threshold(35, 25)
822845
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
823846
# pprint.pprint(result)
@@ -833,6 +856,7 @@ def summary(self):
833856
print(opset.summary())
834857

835858
print("=" * 30 + " Top 125 Summary " + "=" * 30)
859+
print(f'% of usages on github: {get_ops_percentage(100, 25)}')
836860
opset = OperatorSet.from_top125()
837861
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
838862
# pprint.pprint(result)

0 commit comments

Comments
 (0)