@@ -183,7 +183,7 @@ def safe_append(dct, key, val):
183183}
184184
185185
186- def get_top_ops (torch_threshold , nn_fn_threshold ):
186+ def get_top_ops (torch_threshold , nn_fn_threshold , with_counts = False ):
187187 denylist = set ({
188188 # These are either not real "operators", factory functions
189189 # that trivially work, or not-documented ops.
@@ -228,12 +228,17 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
228228 'fft' , # is namespace
229229 })
230230
231- torch_ops = [op [ 0 ] for op in top_ops .top_torch ]
232- nn_fn_ops = [op [ 0 ] for op in top_ops .get_nn_functional_top_list ()]
233- torch_ops = [op for op in torch_ops if op not in denylist ]
234- nn_fn_ops = [op for op in nn_fn_ops if op not in denylist ]
231+ torch_ops = [op for op in top_ops .top_torch ]
232+ nn_fn_ops = [op for op in top_ops .get_nn_functional_top_list ()]
233+ torch_ops = [op for op in torch_ops if op [ 0 ] not in denylist ]
234+ nn_fn_ops = [op for op in nn_fn_ops if op [ 0 ] not in denylist ]
235235
236236 ops = torch_ops [:torch_threshold ] + nn_fn_ops [:nn_fn_threshold ]
237+
238+ # Now, sort by priority
239+ ops .sort (reverse = True , key = lambda op : op [1 ])
240+ if not with_counts :
241+ ops = [op [0 ] for op in ops ]
237242 return ops
238243
239244
@@ -341,8 +346,6 @@ def get_skipped_or_xfailed_ops_for(test_name):
341346 return result
342347
343348
344- # import pdb; pdb.set_trace()
345-
346349def get_statuses (for_subset = None , invert = False ):
347350 overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about ()
348351 if for_subset is not None :
@@ -886,3 +889,8 @@ def summary(self):
886889# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
887890# pprint.pprint(result)
888891# print(opset.summary())
892+
893+ # Print list of everything in order
894+ # all_ops = get_top_ops(999999, 999999, with_counts=True)
895+ # for op, count in all_ops:
896+ # print(f'{op}, {count}')
0 commit comments