|
3 | 3 | from torch.testing._internal.common_methods_invocations import op_db
|
4 | 4 | from enum import Enum
|
5 | 5 | from functorch_lagging_op_db import functorch_lagging_op_db
|
| 6 | +import functorch._src.top_operators_github_usage as top_ops |
6 | 7 |
|
7 | 8 | # Importing these files make modifications to the op_db that we need
|
8 | 9 | import test_ops
|
@@ -150,6 +151,33 @@ def get_ops_covered_by_opinfos():
|
150 | 151 | ops[alias.op] = opinfo
|
151 | 152 | return ops
|
152 | 153 |
|
| 154 | +def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0): |
| 155 | + denylist = set({ |
| 156 | + 'tensor', 'load', 'zeros', 'no_grad', 'save', 'from_numpy', |
| 157 | + 'manual_seed', 'ones', 'randn', 'arange', 'rand', |
| 158 | + 'empty', 'randperm', 'linspace', 'set_grad_enabled', |
| 159 | + 'isnan', 'set_default_tensor_type', 'set_num_threads', |
| 160 | + 'set_printoptions', 'isfinite', 'range', 'numel', |
| 161 | + 'set_default_dtype', 'sparse_coo_tensor', 'set_rng_state', |
| 162 | + 'get_rng_state', 'get_default_dtype', 'initial_seed', |
| 163 | + 'get_num_threads', 'quantize_per_tensor', 'logspace', |
| 164 | + 'hann_window', 'is_tensor', 'as_tensor', 'randint', 'full', 'eye', |
| 165 | + 'equal', |
| 166 | + }) |
| 167 | + torch_ops = [op[0] for op in top_ops.top_torch[:torch_threshold]] |
| 168 | + nn_fn_ops = [op[0] for op in top_ops.top_nn_functional[:nn_fn_threshold]] |
| 169 | + ops = torch_ops + nn_fn_ops |
| 170 | + |
| 171 | + ops_with_opinfo = [] |
| 172 | + for op in op_db: |
| 173 | + ops_with_opinfo.append(op.name) |
| 174 | + ops_with_opinfo.extend([op.name for op in op.aliases]) |
| 175 | + ops_with_opinfo = set(ops_with_opinfo) |
| 176 | + |
| 177 | + result = [op for op in ops if op not in ops_with_opinfo] |
| 178 | + result = [op for op in result if op not in denylist] |
| 179 | + return result |
| 180 | + |
153 | 181 | def get_covered_ops(ops_list, invert=False):
|
154 | 182 | ops_covered_by_opinfo = get_ops_covered_by_opinfos()
|
155 | 183 | overridable_outplace_ops = ops_list
|
@@ -236,3 +264,11 @@ def transpose_statuses():
|
236 | 264 | method_only_ops = get_method_only_ops_we_care_about()
|
237 | 265 | # for op in method_only_ops:
|
238 | 266 | # print(f' {op},')
|
| 267 | + |
| 268 | +top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25) |
| 269 | +for op in top_ops_not_covered_by_opinfo: |
| 270 | + print(op) |
| 271 | + |
| 272 | +# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 50) |
| 273 | +# for op in top_ops_not_covered_by_opinfo: |
| 274 | +# print(op) |
0 commit comments