 ]
 
 # torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different
-def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/whiteboard'):
+def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/quick'):
     results = {}
     all_overridable_apis = set(torch.overrides.get_testing_overrides().keys())
     for module, module_name, src in public_docs:
@@ -159,18 +159,25 @@ def safe_append(dct, key, val):
         safe_append(ops, alias.op, opinfo)
     return ops
 
+factory_fns = {
+    'tensor', 'zeros', 'ones', 'randn', 'arange', 'rand', 'empty', 'randperm',
+    'linspace', 'logspace', 'hann_window', 'full', 'eye', 'blackman_window',
+    'bartlett_window', 'randint', 'range',
+}
+
 def get_top_ops(torch_threshold, nn_fn_threshold):
     denylist = set({
-        'tensor', 'load', 'zeros', 'no_grad', 'save', 'from_numpy',
-        'manual_seed', 'ones', 'randn', 'arange', 'rand',
-        'empty', 'randperm', 'linspace', 'set_grad_enabled',
-        'isnan', 'set_default_tensor_type', 'set_num_threads',
-        'set_printoptions', 'range', 'numel',
+        # These are either not real "operators", factory functions
+        # that trivially work, or undocumented ops.
+        'load', 'no_grad', 'save', 'from_numpy',
+        'manual_seed', 'set_grad_enabled',
+        'set_default_tensor_type', 'set_num_threads',
+        'set_printoptions', 'numel',
         'set_default_dtype', 'sparse_coo_tensor', 'set_rng_state',
         'get_rng_state', 'get_default_dtype', 'initial_seed',
-        'get_num_threads', 'quantize_per_tensor', 'logspace',
-        'hann_window', 'is_tensor', 'as_tensor', 'full', 'eye',
-        'equal', 'enable_grad', 'seed', 'is_storage', 'hamming_window',
+        'get_num_threads', 'quantize_per_tensor',
+        'hann_window', 'is_tensor', 'as_tensor',
+        'equal', 'enable_grad', 'seed', 'is_storage',
         'is_floating_point', 'nn.functional.torch',
         'set_flush_denormal', 'set_num_interop_threads', 'dequantize',
         'get_num_interop_threads', 'nn.functional.math',
@@ -191,8 +198,6 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
         'nn.functional.fractional_max_pool3d_with_indices',
         'is_complex',
         'grad',
-        'bartlett_window',
-        'blackman_window',
         'quantize_per_channel',
         'nn.functional.max_pool2d_with_indices',
         'nn.functional.max_pool3d_with_indices',
@@ -205,10 +210,12 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
         'fft',  # is namespace
     })
 
-    torch_ops = [op[0] for op in top_ops.top_torch[:torch_threshold]]
-    nn_fn_ops = [op[0] for op in top_ops.get_nn_functional_top_list()[:nn_fn_threshold]]
-    ops = torch_ops + nn_fn_ops
-    ops = [op for op in ops if op not in denylist]
+    torch_ops = [op[0] for op in top_ops.top_torch]
+    nn_fn_ops = [op[0] for op in top_ops.get_nn_functional_top_list()]
+    torch_ops = [op for op in torch_ops if op not in denylist]
+    nn_fn_ops = [op for op in nn_fn_ops if op not in denylist]
+
+    ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold]
     return ops
 
 def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
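Note on the hunk above: the reordering changes the semantics of get_top_ops. The denylist is now applied before truncating to the thresholds, so the function returns exactly torch_threshold + nn_fn_threshold ops, rather than fewer whenever a denylisted name happened to land inside the top-N window. A minimal sketch of the difference, using toy data rather than the script's real op rankings:

```python
top = ['add', 'tensor', 'mul']  # pretend 'tensor' is denylisted
deny = {'tensor'}

# Old behavior: slice first, then filter -> only 1 op survives.
old = [op for op in top[:2] if op not in deny]   # ['add']

# New behavior: filter first, then slice -> the full quota of 2 ops.
new = [op for op in top if op not in deny][:2]   # ['add', 'mul']
```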
@@ -222,6 +229,7 @@ def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
 
     result = [op for op in ops if op not in ops_with_opinfo]
     result = [op for op in result if op not in denylist]
+    result = [op for op in result if op not in factory_fns]
     return result
 
 def get_covered_ops(ops_list, invert=False):
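The new factory_fns filter means factory functions such as torch.randn and torch.full stay in the top-ops ranking but no longer appear in the "not covered by OpInfo" report. A quick illustration of the layered filtering; the inputs here are hypothetical stand-ins, not the script's real data:

```python
ops = ['randn', 'load', 'add', 'cumsum']
ops_with_opinfo = {'add'}    # ops already exercised by an OpInfo
denylist = {'load'}          # non-operators, setup functions, etc.
factory_fns = {'randn'}      # factory functions that trivially work

result = [op for op in ops if op not in ops_with_opinfo]
result = [op for op in result if op not in denylist]
result = [op for op in result if op not in factory_fns]
print(result)  # ['cumsum'] -- the only genuinely uncovered op
```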
@@ -369,12 +377,14 @@ def print_coverage_info(th=100, nn=25):
     # Allowed exemptions
     vmap_exemptions = {
         'torch.randn_like',  # randomness
+        'torch.rand_like',  # randomness
         'torch.allclose',  # number output
         'torch.unique',  # dynamic
         'torch.nonzero',  # dynamic
         'torch.masked_select',  # dynamic
         'torch.prod',  # dynamic (backward)
-        'torch.norm',  # norm with nuc is not commonly used.
+        'torch.norm',  # norm with nuc is not commonly used; we support the other cases.
+        'torch.svd',  # There isn't a bug; the output is just nondeterministic, so we can't test it.
     }
     remove_from_set(statuses['test_vmap_exhaustive'], vmap_exemptions)
     remove_from_set(statuses['test_vmapvjp'], vmap_exemptions)
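remove_from_set is not defined anywhere in this diff. Judging from its use here, it just discards the exempted names from a test's set of failing ops. A plausible implementation consistent with that usage; this is an assumption, not the file's actual definition:

```python
def remove_from_set(failing_ops, exemptions):
    # Discard each exempted op name; names that aren't present are ignored.
    for op in exemptions:
        failing_ops.discard(op)
```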
@@ -387,7 +397,14 @@ def print_coverage_info(th=100, nn=25):
     print(f"total ops in set: {th + nn}")
     print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
     for test in tests:
+        if test in {'test_jvp', 'test_vmapjvp'}:
+            continue
         print(f'{test} failing coverage {len(statuses[test])}')
+
+    # We don't care about these yet
+    del statuses['test_jvp']
+    del statuses['test_vmapjvp']
+
     pprint.pprint(statuses)
 
 print_coverage_info(100, 25)
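One small observation on the last hunk: it skips test_jvp and test_vmapjvp in the loop and then deletes the same two keys from statuses before pretty-printing, so the pair of names is spelled out twice. If that list grows, it may be worth factoring into a single set. A sketch of that refactor; this is a suggestion, not part of the commit, and the tests/statuses values below are toy stand-ins to keep it runnable:

```python
import pprint

# Toy stand-ins for the script's real data.
tests = ['test_vmap_exhaustive', 'test_jvp', 'test_vmapjvp']
statuses = {t: set() for t in tests}

IGNORED_TESTS = {'test_jvp', 'test_vmapjvp'}  # we don't care about these yet

for test in tests:
    if test in IGNORED_TESTS:
        continue
    print(f'{test} failing coverage {len(statuses[test])}')

for test in IGNORED_TESTS:
    del statuses[test]

pprint.pprint(statuses)
```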