@@ -237,6 +237,27 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
237
237
return ops
238
238
239
239
240
+ def get_ops_percentage (torch_threshold , nn_fn_threshold ):
241
+ data = top_ops .top_torch + top_ops .get_nn_functional_top_list ()
242
+
243
+ def get_num_usages (opname ):
244
+ # Ignore this, this is heavily inflated
245
+ if opname == 't' :
246
+ return 0
247
+ result = [op [1 ] for op in data if op [0 ] == opname ]
248
+ assert len (result ) == 1
249
+ return result [0 ]
250
+
251
+ # get all operators that are not in the denylist
252
+ all_ops = get_top_ops (999999 , 999999 )
253
+ total_op_usages = sum ([get_num_usages (op ) for op in all_ops ])
254
+
255
+ # get subset of all operators
256
+ subset_ops = get_top_ops (torch_threshold , nn_fn_threshold )
257
+ subset_op_usages = sum ([get_num_usages (op ) for op in subset_ops ])
258
+ return subset_op_usages / total_op_usages
259
+
260
+
240
261
def get_top_ops_not_covered_by_opinfo (torch_threshold = 0 , nn_fn_threshold = 0 ):
241
262
ops = get_top_ops (torch_threshold , nn_fn_threshold )
242
263
@@ -811,13 +832,15 @@ def summary(self):
811
832
has_no_opinfo = opset .query (Operator .has_opinfo , (False ,))
812
833
813
834
print ("=" * 30 + " Summary " + "=" * 30 )
835
+ print (f'% of usages on github: { get_ops_percentage (99999 , 99999 )} ' )
814
836
print (opset .summary ())
815
837
816
838
# sanity checks
817
839
result = opset .query (Operator .supports_vjp , (Support .NO , Support .UNKNOWN ))
818
840
# pprint.pprint(result)
819
841
820
842
print ("=" * 30 + " Top 60 Summary " + "=" * 30 )
843
+ print (f'% of usages on github: { get_ops_percentage (35 , 25 )} ' )
821
844
opset = OperatorSet .from_top_ops_threshold (35 , 25 )
822
845
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
823
846
# pprint.pprint(result)
@@ -833,6 +856,7 @@ def summary(self):
833
856
print (opset .summary ())
834
857
835
858
print ("=" * 30 + " Top 125 Summary " + "=" * 30 )
859
+ print (f'% of usages on github: { get_ops_percentage (100 , 25 )} ' )
836
860
opset = OperatorSet .from_top125 ()
837
861
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
838
862
# pprint.pprint(result)
0 commit comments