Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit cda0348

Browse files
authored
Added tests for op aliases (#173)
* WIP on adding tests for op aliases
* Removed aliases from stop decomposition and added alias op check in test_vjp
* Removed alias op check from test_vjpvmap
* Removed more divide aliases from stop decompose
1 parent 04d5ff1 commit cda0348

File tree

4 files changed

+29
-47
lines changed

4 files changed

+29
-47
lines changed

functorch/csrc/BatchRulesStopDecomposition.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
106106
STOP_DECOMPOSE(_version);
107107
STOP_DECOMPOSE(_weight_norm);
108108
STOP_DECOMPOSE(_weight_norm_differentiable_backward);
109-
STOP_DECOMPOSE(absolute);
110109
STOP_DECOMPOSE(absolute.out);
111110
STOP_DECOMPOSE(absolute_);
112111
STOP_DECOMPOSE(adaptive_max_pool1d);
@@ -127,22 +126,16 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
127126
STOP_DECOMPOSE(arange.out);
128127
STOP_DECOMPOSE(arange.start);
129128
STOP_DECOMPOSE(arange.start_step);
130-
STOP_DECOMPOSE(arccos);
131129
STOP_DECOMPOSE(arccos.out);
132130
STOP_DECOMPOSE(arccos_);
133-
STOP_DECOMPOSE(arccosh);
134131
STOP_DECOMPOSE(arccosh.out);
135132
STOP_DECOMPOSE(arccosh_);
136-
STOP_DECOMPOSE(arcsin);
137133
STOP_DECOMPOSE(arcsin.out);
138134
STOP_DECOMPOSE(arcsin_);
139-
STOP_DECOMPOSE(arcsinh);
140135
STOP_DECOMPOSE(arcsinh.out);
141136
STOP_DECOMPOSE(arcsinh_);
142-
STOP_DECOMPOSE(arctan);
143137
STOP_DECOMPOSE(arctan.out);
144138
STOP_DECOMPOSE(arctan_);
145-
STOP_DECOMPOSE(arctanh);
146139
STOP_DECOMPOSE(arctanh.out);
147140
STOP_DECOMPOSE(arctanh_);
148141
STOP_DECOMPOSE(argsort.dimname);
@@ -207,10 +200,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
207200
STOP_DECOMPOSE(diagflat);
208201
STOP_DECOMPOSE(diagonal.Dimname);
209202
STOP_DECOMPOSE(diff.out);
210-
STOP_DECOMPOSE(divide.Scalar);
211-
STOP_DECOMPOSE(divide.Scalar_mode);
212-
STOP_DECOMPOSE(divide.Tensor);
213-
STOP_DECOMPOSE(divide.Tensor_mode);
214203
STOP_DECOMPOSE(divide.out);
215204
STOP_DECOMPOSE(divide.out_mode);
216205
STOP_DECOMPOSE(divide_.Scalar);
@@ -272,7 +261,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
272261
STOP_DECOMPOSE(fft_rfftfreq.out);
273262
STOP_DECOMPOSE(fft_rfftn.out);
274263
STOP_DECOMPOSE(fill_diagonal_);
275-
STOP_DECOMPOSE(fix);
276264
STOP_DECOMPOSE(fix.out);
277265
STOP_DECOMPOSE(fix_);
278266
STOP_DECOMPOSE(flatten.DimnameList);
@@ -301,15 +289,11 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
301289
STOP_DECOMPOSE(gradient.scalarrayint);
302290
STOP_DECOMPOSE(gradient.tensorarray);
303291
STOP_DECOMPOSE(gradient.tensorarrayint);
304-
STOP_DECOMPOSE(greater.Scalar);
305292
STOP_DECOMPOSE(greater.Scalar_out);
306-
STOP_DECOMPOSE(greater.Tensor);
307293
STOP_DECOMPOSE(greater.Tensor_out);
308294
STOP_DECOMPOSE(greater_.Scalar);
309295
STOP_DECOMPOSE(greater_.Tensor);
310-
STOP_DECOMPOSE(greater_equal.Scalar);
311296
STOP_DECOMPOSE(greater_equal.Scalar_out);
312-
STOP_DECOMPOSE(greater_equal.Tensor);
313297
STOP_DECOMPOSE(greater_equal.Tensor_out);
314298
STOP_DECOMPOSE(greater_equal_.Scalar);
315299
STOP_DECOMPOSE(greater_equal_.Tensor);
@@ -362,15 +346,11 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
362346
STOP_DECOMPOSE(ldexp.Tensor);
363347
STOP_DECOMPOSE(ldexp.out);
364348
STOP_DECOMPOSE(ldexp_);
365-
STOP_DECOMPOSE(less.Scalar);
366349
STOP_DECOMPOSE(less.Scalar_out);
367-
STOP_DECOMPOSE(less.Tensor);
368350
STOP_DECOMPOSE(less.Tensor_out);
369351
STOP_DECOMPOSE(less_.Scalar);
370352
STOP_DECOMPOSE(less_.Tensor);
371-
STOP_DECOMPOSE(less_equal.Scalar);
372353
STOP_DECOMPOSE(less_equal.Scalar_out);
373-
STOP_DECOMPOSE(less_equal.Tensor);
374354
STOP_DECOMPOSE(less_equal.Tensor_out);
375355
STOP_DECOMPOSE(less_equal_.Scalar);
376356
STOP_DECOMPOSE(less_equal_.Tensor);
@@ -437,13 +417,9 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
437417
STOP_DECOMPOSE(mkldnn_convolution_backward_weights);
438418
STOP_DECOMPOSE(mode.dimname);
439419
STOP_DECOMPOSE(mode.dimname_out);
440-
STOP_DECOMPOSE(moveaxis.int);
441-
STOP_DECOMPOSE(moveaxis.intlist);
442420
STOP_DECOMPOSE(msort.out);
443421
STOP_DECOMPOSE(multilabel_margin_loss);
444422
STOP_DECOMPOSE(multilabel_margin_loss.out);
445-
STOP_DECOMPOSE(multiply.Scalar);
446-
STOP_DECOMPOSE(multiply.Tensor);
447423
STOP_DECOMPOSE(multiply.out);
448424
STOP_DECOMPOSE(multiply_.Scalar);
449425
STOP_DECOMPOSE(multiply_.Tensor);
@@ -455,7 +431,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
455431
STOP_DECOMPOSE(nanquantile.scalar_out);
456432
STOP_DECOMPOSE(narrow.Tensor);
457433
STOP_DECOMPOSE(native_layer_norm);
458-
STOP_DECOMPOSE(negative);
459434
STOP_DECOMPOSE(negative.out);
460435
STOP_DECOMPOSE(negative_);
461436
STOP_DECOMPOSE(new_empty);
@@ -472,9 +447,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
472447
STOP_DECOMPOSE(norm_except_dim);
473448
STOP_DECOMPOSE(normal.float_float);
474449
STOP_DECOMPOSE(normal.float_float_out);
475-
STOP_DECOMPOSE(not_equal.Scalar);
476450
STOP_DECOMPOSE(not_equal.Scalar_out);
477-
STOP_DECOMPOSE(not_equal.Tensor);
478451
STOP_DECOMPOSE(not_equal.Tensor_out);
479452
STOP_DECOMPOSE(not_equal_.Scalar);
480453
STOP_DECOMPOSE(not_equal_.Tensor);
@@ -585,12 +558,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
585558
STOP_DECOMPOSE(special_gammaln.out);
586559
STOP_DECOMPOSE(special_i0.out);
587560
STOP_DECOMPOSE(special_log1p.out);
588-
STOP_DECOMPOSE(special_log_softmax);
589-
STOP_DECOMPOSE(special_logit);
590561
STOP_DECOMPOSE(special_logit.out);
591-
STOP_DECOMPOSE(special_logsumexp);
592562
STOP_DECOMPOSE(special_logsumexp.out);
593-
STOP_DECOMPOSE(special_multigammaln);
594563
STOP_DECOMPOSE(special_multigammaln.out);
595564
STOP_DECOMPOSE(special_ndtr.out);
596565
STOP_DECOMPOSE(special_polygamma.out);
@@ -614,18 +583,14 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
614583
STOP_DECOMPOSE(stft);
615584
STOP_DECOMPOSE(stride.Dimname);
616585
STOP_DECOMPOSE(stride.int);
617-
STOP_DECOMPOSE(subtract.Scalar);
618-
STOP_DECOMPOSE(subtract.Tensor);
619586
STOP_DECOMPOSE(subtract.out);
620587
STOP_DECOMPOSE(subtract_.Scalar);
621588
STOP_DECOMPOSE(subtract_.Tensor);
622589
STOP_DECOMPOSE(sum.DimnameList_out);
623590
STOP_DECOMPOSE(sum.dim_DimnameList);
624591
STOP_DECOMPOSE(sum_to_size);
625592
STOP_DECOMPOSE(svd.U);
626-
STOP_DECOMPOSE(swapaxes);
627593
STOP_DECOMPOSE(swapaxes_);
628-
STOP_DECOMPOSE(swapdims);
629594
STOP_DECOMPOSE(swapdims_);
630595
STOP_DECOMPOSE(take_along_dim.out);
631596
STOP_DECOMPOSE(tensor_split.tensor_indices_or_sections);

test/functorch_lagging_op_db.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@
259259
('nn.functional.interpolate', 'trilinear'),
260260
('nn.functional.layer_norm', ''),
261261
('nn.functional.leaky_relu', ''),
262+
('nn.functional.linear', ''),
262263
('nn.functional.logsigmoid', ''),
263264
('nn.functional.max_pool2d', ''),
264265
('nn.functional.mse_loss', ''),
@@ -301,6 +302,7 @@
301302
('remainder', 'autodiffed'),
302303
('renorm', ''),
303304
('repeat', ''),
305+
('repeat_interleave', ''),
304306
('reshape', ''),
305307
('reshape_as', ''),
306308
('resize_', ''),

test/test_ops.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -206,19 +206,24 @@ def test_vjp(self, device, dtype, op):
206206
self.skipTest("Skipped! NYI: inplace-testing not supported.")
207207
return
208208

209-
for sample in samples:
210-
fn, primals = normalize_op_for_vjp(op, sample)
211-
result = fn(*primals)
212-
cotangents = tree_map(lambda x: torch.randn_like(x), result)
209+
def _test(_op):
210+
for sample in samples:
211+
fn, primals = normalize_op_for_vjp(_op, sample)
212+
result = fn(*primals)
213+
cotangents = tree_map(lambda x: torch.randn_like(x), result)
213214

214-
out, vjp_fn = vjp(fn, *primals)
215-
self.assertEqual(out, result)
216-
result_vjps = vjp_fn(cotangents)
215+
out, vjp_fn = vjp(fn, *primals)
216+
self.assertEqual(out, result)
217+
result_vjps = vjp_fn(cotangents)
217218

218-
_, vjp_fn = ref_vjp(fn, *primals)
219-
expected_vjps = vjp_fn(cotangents)
219+
_, vjp_fn = ref_vjp(fn, *primals)
220+
expected_vjps = vjp_fn(cotangents)
220221

221-
self.assertEqual(result_vjps, expected_vjps)
222+
self.assertEqual(result_vjps, expected_vjps)
223+
224+
_test(op)
225+
for a_op in op.aliases:
226+
_test(a_op)
222227

223228
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
224229
@skipOps('TestOperators', 'test_vjpvjp', vjp_fail)
@@ -452,6 +457,7 @@ def test_vmapvjp(self, device, dtype, op):
452457
xfail('put'),
453458
xfail('quantile'),
454459
xfail('renorm'),
460+
xfail('repeat_interleave'),
455461
xfail('scatter_add'),
456462
xfail('solve'),
457463
xfail('sort'),
@@ -493,7 +499,10 @@ def test():
493499
fn, args = get_vjpfull_variant(op, sample)
494500
for _ in get_fallback_and_vmap_exhaustive(fn, args, {}, compute_loop_out=False):
495501
pass
496-
502+
for a_op in op.aliases:
503+
fn, args = get_vjpfull_variant(a_op, sample)
504+
for _ in get_fallback_and_vmap_exhaustive(fn, args, {}, compute_loop_out=False):
505+
pass
497506
check_vmap_fallback(self, test, op, dry_run=False)
498507

499508
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -560,7 +569,6 @@ def test_vjpvmap(self, device, dtype, op):
560569

561570
self.assertEqual(result_vjps, expected_vjps)
562571

563-
564572
only_for = ("cpu", "cuda")
565573
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
566574

test/test_vmap.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2980,6 +2980,9 @@ def test_vmap_exhaustive(self, device, dtype, op):
29802980
try:
29812981
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values):
29822982
self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
2983+
for a_op in op.aliases:
2984+
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values):
2985+
self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
29832986
except Exception as e:
29842987
# Checking if we're throwing an error because of dynamic shapes.
29852988
if "dynamic" in e.args[0]:
@@ -3045,6 +3048,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
30453048
xfail('quantile'),
30463049
xfail('ravel'),
30473050
xfail('renorm'),
3051+
xfail('repeat_interleave'),
30483052
xfail('resize_as_'),
30493053
xfail('resolve_conj'),
30503054
xfail('resolve_neg'),
@@ -3083,6 +3087,9 @@ def test():
30833087
kwarg_values = sample_input.kwargs
30843088
for _ in get_fallback_and_vmap_exhaustive(op.op, arg_values, kwarg_values, compute_loop_out=False):
30853089
pass
3090+
for a_op in op.aliases:
3091+
for _ in get_fallback_and_vmap_exhaustive(a_op, arg_values, kwarg_values, compute_loop_out=False):
3092+
pass
30863093
check_vmap_fallback(self, test, op)
30873094

30883095
def test_isnan(self, device):

0 commit comments

Comments (0)