Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit ca0f044

Browse files
authored
Update functorch lagging op db (#652)
1 parent b94ece0 commit ca0f044

File tree

5 files changed

+123
-11
lines changed

5 files changed

+123
-11
lines changed

test/functorch_lagging_op_db.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
('_masked.prod', ''),
3636
('_masked.softmax', ''),
3737
('_masked.softmin', ''),
38+
('_masked.std', ''),
3839
('_masked.sum', ''),
3940
('_masked.var', ''),
4041
('abs', ''),
@@ -69,6 +70,7 @@
6970
('atleast_2d', ''),
7071
('atleast_3d', ''),
7172
('baddbmm', ''),
73+
('bernoulli', ''),
7274
('bfloat16', ''),
7375
('bfloat16', 'channels_last'),
7476
('bincount', ''),
@@ -100,6 +102,7 @@
100102
('clamp', ''),
101103
('clamp', 'scalar'),
102104
('clone', ''),
105+
('column_stack', ''),
103106
('combinations', ''),
104107
('complex', ''),
105108
('conj', ''),
@@ -120,6 +123,7 @@
120123
('deg2rad', ''),
121124
('diag', ''),
122125
('diag_embed', ''),
126+
('diagflat', ''),
123127
('diagonal', ''),
124128
('diagonal_scatter', ''),
125129
('diff', ''),
@@ -166,6 +170,7 @@
166170
('fft.rfft2', ''),
167171
('fft.rfftn', ''),
168172
('fill_', ''),
173+
('flatten', ''),
169174
('flip', ''),
170175
('fliplr', ''),
171176
('flipud', ''),
@@ -177,7 +182,6 @@
177182
('fmax', ''),
178183
('fmin', ''),
179184
('fmod', ''),
180-
('fmod', 'autodiffed'),
181185
('frac', ''),
182186
('frexp', ''),
183187
('full_like', ''),
@@ -198,9 +202,7 @@
198202
('hypot', ''),
199203
('i0', ''),
200204
('igamma', ''),
201-
('igamma', 'grad_other'),
202205
('igammac', ''),
203-
('igammac', 'grad_other'),
204206
('imag', ''),
205207
('index_add', ''),
206208
('index_copy', ''),
@@ -242,6 +244,8 @@
242244
('linalg.inv_ex', ''),
243245
('linalg.lstsq', ''),
244246
('linalg.lstsq', 'grad_oriented'),
247+
('linalg.lu_factor', ''),
248+
('linalg.lu_factor_ex', ''),
245249
('linalg.matrix_norm', ''),
246250
('linalg.matrix_power', ''),
247251
('linalg.matrix_rank', ''),
@@ -306,6 +310,7 @@
306310
('movedim', ''),
307311
('msort', ''),
308312
('mul', ''),
313+
('multinomial', ''),
309314
('mv', ''),
310315
('mvlgamma', 'mvlgamma_p_1'),
311316
('mvlgamma', 'mvlgamma_p_3'),
@@ -346,10 +351,12 @@
346351
('nn.functional.cross_entropy', ''),
347352
('nn.functional.ctc_loss', ''),
348353
('nn.functional.dropout', ''),
354+
('nn.functional.dropout2d', ''),
349355
('nn.functional.elu', ''),
350356
('nn.functional.embedding', ''),
351357
('nn.functional.embedding_bag', ''),
352-
('nn.functional.feature_alpha_dropout', ''),
358+
('nn.functional.feature_alpha_dropout', 'with_train'),
359+
('nn.functional.feature_alpha_dropout', 'without_train'),
353360
('nn.functional.fractional_max_pool2d', ''),
354361
('nn.functional.fractional_max_pool3d', ''),
355362
('nn.functional.gaussian_nll_loss', ''),
@@ -370,6 +377,7 @@
370377
('nn.functional.interpolate', 'linear'),
371378
('nn.functional.interpolate', 'nearest'),
372379
('nn.functional.interpolate', 'trilinear'),
380+
('nn.functional.kl_div', ''),
373381
('nn.functional.layer_norm', ''),
374382
('nn.functional.leaky_relu', ''),
375383
('nn.functional.linear', ''),
@@ -412,9 +420,12 @@
412420
('norm', 'fro'),
413421
('norm', 'inf'),
414422
('norm', 'nuc'),
423+
('normal', ''),
424+
('normal', 'number_mean'),
415425
('ones_like', ''),
416426
('ormqr', ''),
417427
('outer', ''),
428+
('pca_lowrank', ''),
418429
('permute', ''),
419430
('pinverse', ''),
420431
('polar', ''),
@@ -437,7 +448,6 @@
437448
('real', ''),
438449
('reciprocal', ''),
439450
('remainder', ''),
440-
('remainder', 'autodiffed'),
441451
('renorm', ''),
442452
('repeat', ''),
443453
('repeat_interleave', ''),
@@ -450,11 +460,14 @@
450460
('roll', ''),
451461
('rot90', ''),
452462
('round', ''),
463+
('round', 'decimals_0'),
464+
('round', 'decimals_3'),
465+
('round', 'decimals_neg_3'),
453466
('rsqrt', ''),
454-
('rsub', 'rsub_scalar'),
455-
('rsub', 'rsub_tensor'),
467+
('rsub', ''),
456468
('scatter', ''),
457469
('scatter_add', ''),
470+
('scatter_reduce', ''),
458471
('searchsorted', ''),
459472
('select', ''),
460473
('select_scatter', ''),
@@ -477,6 +490,7 @@
477490
('special.i0e', ''),
478491
('special.i1', ''),
479492
('special.i1e', ''),
493+
('special.log_ndtr', ''),
480494
('special.ndtr', ''),
481495
('special.ndtri', ''),
482496
('special.polygamma', 'special_polygamma_n_0'),
@@ -496,6 +510,7 @@
496510
('sum', ''),
497511
('sum_to_size', ''),
498512
('svd', ''),
513+
('svd_lowrank', ''),
499514
('symeig', ''),
500515
('t', ''),
501516
('take', ''),

test/test_ops.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,11 @@ def is_inplace(op, variant):
304304
xfail('tensor_split'),
305305
xfail('to_sparse'),
306306
xfail('nn.functional.ctc_loss'),
307+
skip('nn.functional.feature_alpha_dropout', 'with_train'), # fails on cuda, runs okay on cpu
308+
skip('nn.functional.feature_alpha_dropout', 'without_train'), # fails on cuda, runs okay on cpu
309+
skip('pca_lowrank', ''), # fails on cuda, runs okay on cpu
310+
skip('svd_lowrank', ''), # fails on cuda, runs okay on cpu
311+
skip('nn.functional.dropout2d', ''), # fails on cuda, runs okay on cpu
307312
}
308313

309314

@@ -364,6 +369,11 @@ def wrapped_fn(*args, **kwargs):
364369
skip('nn.functional.fractional_max_pool2d'), # fails on cuda, runs okay on cpu
365370
skip('nn.functional.fractional_max_pool3d'), # fails on cuda, runs okay on cpu
366371
skip('nn.functional.max_pool1d'), # fails on cpu, runs okay on cuda
372+
skip('nn.functional.feature_alpha_dropout', 'with_train'), # fails on cuda, runs okay on cpu
373+
skip('nn.functional.feature_alpha_dropout', 'without_train'), # fails on cuda, runs okay on cpu
374+
skip('pca_lowrank', ''), # fails on cuda, runs okay on cpu
375+
skip('svd_lowrank', ''), # fails on cuda, runs okay on cpu
376+
skip('nn.functional.dropout2d', ''), # fails on cuda, runs okay on cpu
367377
368378
# See https://github.com/pytorch/pytorch/issues/69034
369379
# RuntimeError: expected scalar type double but found float
@@ -394,8 +404,6 @@ def wrapped_fn(*args, **kwargs):
394404
# Some kind of issue with unsymmetric tangent type
395405
# Runtime Error: The tangent part of the matrix A should also be symmetric.
396406
xfail('linalg.eigh'),
397-
398-
399407
}))
400408
@opsToleranceOverride('TestOperators', 'test_jvp', (
401409
tol1('nn.functional.conv_transpose3d',
@@ -430,6 +438,11 @@ def test_jvp(self, device, dtype, op):
430438
@skipOps('TestOperators', 'test_vjp', vjp_fail.union({
431439
skip('nn.functional.fractional_max_pool2d'), # fails on cpu, runs okay on cuda
432440
skip('nn.functional.fractional_max_pool3d'), # fails on cpu, runs okay on cuda
441+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
442+
xfail('pca_lowrank', ''),
443+
xfail('nn.functional.dropout2d', ''),
444+
xfail('nn.functional.feature_alpha_dropout', 'without_train'),
445+
xfail('svd_lowrank', ''),
433446
}))
434447
@opsToleranceOverride('TestOperators', 'test_vjp', (
435448
tol1('nn.functional.conv_transpose3d',
@@ -613,6 +626,9 @@ def vjp_of_vjp(*args_and_cotangents):
613626
xfail('lu_solve'),
614627
xfail('index_copy'),
615628
xfail('nn.functional.gelu', device_type='cpu'),
629+
630+
xfail('linalg.lu_factor', ''),
631+
xfail('scatter_reduce', '', device_type='cpu'),
616632
})
617633

618634
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -710,6 +726,12 @@ def test_vmapvjp(self, device, dtype, op):
710726
# Some kind of issue with unsymmetric tangent type
711727
# Runtime Error: The tangent part of the matrix A should also be symmetric.
712728
xfail('linalg.eigh'),
729+
730+
skip('nn.functional.feature_alpha_dropout', 'with_train'),
731+
skip('pca_lowrank', ''),
732+
skip('nn.functional.dropout2d', ''),
733+
skip('nn.functional.feature_alpha_dropout', 'without_train'),
734+
skip('svd_lowrank', ''),
713735
})
714736
def test_vmapjvp(self, device, dtype, op):
715737
if is_inplace(op, op.get_op()):
@@ -781,6 +803,13 @@ def test_vmapjvp(self, device, dtype, op):
781803
# Some kind of issue with unsymmetric tangent type
782804
# Runtime Error: The tangent part of the matrix A should also be symmetric.
783805
xfail('linalg.eigh'),
806+
807+
xfail('linalg.lu_factor', ''),
808+
skip('nn.functional.dropout2d', ''),
809+
skip('nn.functional.feature_alpha_dropout', 'without_train'),
810+
skip('pca_lowrank', ''),
811+
skip('svd_lowrank', ''),
812+
skip('nn.functional.feature_alpha_dropout', 'with_train'),
784813
}
785814

786815
@ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -843,6 +872,15 @@ def test_vmapjvpall(self, device, dtype, op):
843872
xfail('nn.functional.max_pool3d'),
844873
xfail('vdot'),
845874
xfail('linalg.cross'),
875+
xfail('nn.functional.feature_alpha_dropout', 'without_train'),
876+
xfail('linalg.lu_factor', ''),
877+
xfail('nn.functional.dropout2d', ''),
878+
xfail('nn.functional.kl_div', ''),
879+
xfail('pca_lowrank', ''),
880+
xfail('svd_lowrank', ''),
881+
xfail('linalg.lu_factor_ex', ''),
882+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
883+
xfail('special.log_ndtr', ''),
846884
}))
847885
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
848886
def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -884,6 +922,7 @@ def test():
884922
xfail('fmax'),
885923
xfail('fft.ihfft'),
886924
xfail('fft.rfft'),
925+
xfail('special.log_ndtr'),
887926
xfail('fft.rfftn'),
888927
xfail('fill_'),
889928
xfail('index_copy'),
@@ -953,6 +992,15 @@ def test():
953992
xfail('istft'),
954993
xfail('nn.functional.fractional_max_pool2d'),
955994
xfail('linalg.tensorsolve'),
995+
xfail('linalg.lu_factor', ''),
996+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
997+
xfail('nn.functional.kl_div', ''),
998+
xfail('scatter_reduce', '', device_type='cpu'),
999+
xfail('pca_lowrank', ''),
1000+
xfail('nn.functional.dropout2d', ''),
1001+
xfail('nn.functional.feature_alpha_dropout', 'without_train'),
1002+
xfail('svd_lowrank', ''),
1003+
xfail('linalg.lu_factor_ex', ''),
9561004
}))
9571005
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
9581006
if not op.supports_autograd:
@@ -1002,6 +1050,12 @@ def test():
10021050
xfail('as_strided'),
10031051
skip('nn.functional.fractional_max_pool2d'), # generator works on cpu, fails on cuda
10041052
skip('solve'),
1053+
xfail('column_stack', ''),
1054+
xfail('nn.functional.dropout2d', ''),
1055+
xfail('svd_lowrank', ''),
1056+
xfail('pca_lowrank', ''),
1057+
xfail('nn.functional.feature_alpha_dropout', 'without_train'),
1058+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
10051059
}))
10061060
def test_vjpvmap(self, device, dtype, op):
10071061
# NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1158,6 +1212,16 @@ def test_vjpvmap(self, device, dtype, op):
11581212
xfail('symeig', ''),
11591213
xfail('take', ''),
11601214
xfail('var_mean', ''),
1215+
xfail('linalg.lu_factor', ''),
1216+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
1217+
xfail('nn.functional.kl_div', ''),
1218+
xfail('scatter_reduce', '', device_type='cpu'),
1219+
xfail('pca_lowrank', ''),
1220+
xfail('nn.functional.dropout2d', ''),
1221+
xfail('nn.functional.feature_alpha_dropout', 'without_train'),
1222+
xfail('svd_lowrank', ''),
1223+
xfail('rsub', ''),
1224+
xfail('linalg.lu_factor_ex', ''),
11611225
}))
11621226
def test_jvpvjp(self, device, dtype, op):
11631227
if not op.supports_autograd:
@@ -1256,6 +1320,7 @@ class TestDecompositionOpInfo(TestCase):
12561320
skip('stft'),
12571321
skip('_masked.softmax'),
12581322
skip('_masked.normalize'),
1323+
xfail('linalg.lu_factor', ''),
12591324
# Some weird matmul stuff with int64 matmuls
12601325
# inplace op
12611326
skip('resize_'),

test/test_pythonkey.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ def f(x):
200200
skip('new_empty'), # nondeterministic
201201
skip('empty_like'), # nondeterministic
202202
skip('linalg.lstsq', 'grad_oriented'), # flaky
203+
xfail('normal', '', device_type='cpu'),
204+
xfail('normal', 'number_mean', device_type='cpu'),
205+
xfail('multinomial', device_type='cpu'),
206+
xfail('nn.functional.feature_alpha_dropout', 'with_train', device_type='cpu'),
207+
xfail('bernoulli', device_type='cpu'),
208+
xfail('nn.functional.dropout2d', device_type='cpu'),
203209
}
204210

205211

test/test_vmap.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3106,6 +3106,15 @@ class TestVmapOperatorsOpInfo(TestCase):
31063106
xfail('nn.functional.glu'),
31073107
xfail('nn.functional.rrelu'), # random?
31083108
xfail('__rpow__'), # https://github.com/pytorch/functorch/issues/617
3109+
xfail('bernoulli', ''),
3110+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
3111+
xfail('multinomial', ''),
3112+
xfail('column_stack', ''),
3113+
xfail('pca_lowrank', ''),
3114+
xfail('normal', ''),
3115+
xfail('nn.functional.dropout2d', ''),
3116+
xfail('normal', 'number_mean'),
3117+
xfail('svd_lowrank', ''),
31093118
}
31103119

31113120
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -3237,6 +3246,21 @@ def test_vmap_exhaustive(self, device, dtype, op):
32373246
xfail('nn.functional.bilinear'),
32383247
xfail('nn.functional.embedding_bag'),
32393248
xfail('linalg.tensorsolve'),
3249+
xfail('bernoulli', ''),
3250+
xfail('linalg.lu_factor', ''),
3251+
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
3252+
xfail('nn.functional.kl_div', ''),
3253+
xfail('multinomial', ''),
3254+
xfail('scatter_reduce', '', device_type='cpu'),
3255+
xfail('column_stack', ''),
3256+
xfail('pca_lowrank', ''),
3257+
xfail('normal', ''),
3258+
xfail('nn.functional.dropout2d', ''),
3259+
xfail('normal', 'number_mean'),
3260+
xfail('svd_lowrank', ''),
3261+
xfail('linalg.lu_factor_ex', ''),
3262+
xfail('diagflat', ''),
3263+
xfail('special.log_ndtr'),
32403264
}))
32413265
def test_op_has_batch_rule(self, device, dtype, op):
32423266
def test():

test/xfail_suggester.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,16 @@ def get_failed_test(line):
2828
'test_vmapvjp_',
2929
'test_vmapvjp_has_batch_rule_',
3030
'test_vjpvmap_',
31-
'test_vmap_exhaustive_',
32-
'test_op_has_batch_rule_',
3331
'test_jvp_',
3432
'test_vmapjvp_',
33+
'test_vmapjvpall_has_batch_rule',
3534
'test_vmapjvpall_',
3635
'test_jvpvjp_',
36+
'test_vjpvjp_',
3737
'test_decomposition_',
3838
'test_make_fx_',
39+
'test_vmap_exhaustive_',
40+
'test_op_has_batch_rule_',
3941
}
4042

4143
failed_tests = [get_failed_test(line) for line in lines]

0 commit comments

Comments (0)