Skip to content

Commit 863602a

Browse files
committed
discover_coverage update
1 parent ed6787a commit 863602a

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

test/discover_coverage.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
]
2323

2424
# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different
25-
def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/whiteboard'):
25+
def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/quick'):
2626
results = {}
2727
all_overridable_apis = set(torch.overrides.get_testing_overrides().keys())
2828
for module, module_name, src in public_docs:
@@ -159,18 +159,25 @@ def safe_append(dct, key, val):
159159
safe_append(ops, alias.op, opinfo)
160160
return ops
161161

162+
factory_fns = {
163+
'tensor', 'zeros', 'ones', 'randn', 'arange', 'rand', 'empty', 'randperm',
164+
'linspace', 'logspace', 'hann_window', 'full', 'eye', 'blackman_window',
165+
'bartlett_window', 'randint', 'range', 'arange',
166+
}
167+
162168
def get_top_ops(torch_threshold, nn_fn_threshold):
163169
denylist = set({
164-
'tensor', 'load', 'zeros', 'no_grad', 'save', 'from_numpy',
165-
'manual_seed', 'ones', 'randn', 'arange', 'rand',
166-
'empty', 'randperm', 'linspace', 'set_grad_enabled',
167-
'isnan', 'set_default_tensor_type', 'set_num_threads',
168-
'set_printoptions', 'range', 'numel',
170+
# These are either not real "operators", factory functions
171+
# that trivially work, or not-documented ops.
172+
'load', 'no_grad', 'save', 'from_numpy',
173+
'manual_seed', 'set_grad_enabled',
174+
'set_default_tensor_type', 'set_num_threads',
175+
'set_printoptions', 'numel',
169176
'set_default_dtype', 'sparse_coo_tensor', 'set_rng_state',
170177
'get_rng_state', 'get_default_dtype', 'initial_seed',
171-
'get_num_threads', 'quantize_per_tensor', 'logspace',
172-
'hann_window', 'is_tensor', 'as_tensor', 'full', 'eye',
173-
'equal', 'enable_grad', 'seed', 'is_storage', 'hamming_window',
178+
'get_num_threads', 'quantize_per_tensor',
179+
'hann_window', 'is_tensor', 'as_tensor',
180+
'equal', 'enable_grad', 'seed', 'is_storage',
174181
'is_floating_point', 'nn.functional.torch',
175182
'set_flush_denormal', 'set_num_interop_threads', 'dequantize',
176183
'get_num_interop_threads', 'nn.functional.math',
@@ -191,8 +198,6 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
191198
'nn.functional.fractional_max_pool3d_with_indices',
192199
'is_complex',
193200
'grad',
194-
'bartlett_window',
195-
'blackman_window',
196201
'quantize_per_channel',
197202
'nn.functional.max_pool2d_with_indices',
198203
'nn.functional.max_pool3d_with_indices',
@@ -205,10 +210,12 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
205210
'fft', # is namespace
206211
})
207212

208-
torch_ops = [op[0] for op in top_ops.top_torch[:torch_threshold]]
209-
nn_fn_ops = [op[0] for op in top_ops.get_nn_functional_top_list()[:nn_fn_threshold]]
210-
ops = torch_ops + nn_fn_ops
211-
ops = [op for op in ops if op not in denylist]
213+
torch_ops = [op[0] for op in top_ops.top_torch]
214+
nn_fn_ops = [op[0] for op in top_ops.get_nn_functional_top_list()]
215+
torch_ops = [op for op in torch_ops if op not in denylist]
216+
nn_fn_ops = [op for op in nn_fn_ops if op not in denylist]
217+
218+
ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold]
212219
return ops
213220

214221
def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
@@ -222,6 +229,7 @@ def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
222229

223230
result = [op for op in ops if op not in ops_with_opinfo]
224231
result = [op for op in result if op not in denylist]
232+
result = [op for op in result if op not in factory_fns]
225233
return result
226234

227235
def get_covered_ops(ops_list, invert=False):
@@ -369,12 +377,14 @@ def print_coverage_info(th=100, nn=25):
369377
# Allowed exemptions
370378
vmap_exemptions = {
371379
'torch.randn_like', # randomness
380+
'torch.rand_like', # randomness
372381
'torch.allclose', # number output
373382
'torch.unique', # dynamic
374383
'torch.nonzero', # dynamic
375384
'torch.masked_select', # dynamic
376385
'torch.prod', # dynamic (backward)
377-
'torch.norm', # norm with nuc is not commonly used.
386+
'torch.norm', # norm with nuc is not commonly used; we support the other cases.
387+
'torch.svd', # There isn't a bug, it is just nondeterministic so we can't test it.
378388
}
379389
remove_from_set(statuses['test_vmap_exhaustive'], vmap_exemptions)
380390
remove_from_set(statuses['test_vmapvjp'], vmap_exemptions)
@@ -387,7 +397,14 @@ def print_coverage_info(th=100, nn=25):
387397
print(f"total ops in set: {th + nn}")
388398
print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
389399
for test in tests:
400+
if test in {'test_jvp', 'test_vmapjvp'}:
401+
continue
390402
print(f'{test} failing coverage {len(statuses[test])}')
403+
404+
# We don't care about these yet
405+
del statuses['test_jvp']
406+
del statuses['test_vmapjvp']
407+
391408
pprint.pprint(statuses)
392409

393410
print_coverage_info(100, 25)

0 commit comments

Comments
 (0)