Skip to content

Commit c2016e4

Browse files
author
Samantha Andow
authored
Fix CI (#686)
* fix ci * typo * move from vmapjvpall -> vmapjvpall_has_batch_rule
1 parent fb6f749 commit c2016e4

File tree

1 file changed

+7
-23
lines changed

1 file changed

+7
-23
lines changed

test/test_ops.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ def test_vmapvjp(self, device, dtype, op):
710710
skip('nn.functional.dropout2d', ''),
711711
skip('nn.functional.feature_alpha_dropout', 'without_train'),
712712
skip('svd_lowrank', ''),
713+
xfail('stft'), # something weird is happening with shapes
713714
})
714715
def test_vmapjvp(self, device, dtype, op):
715716
if is_inplace(op, op.get_op()):
@@ -783,6 +784,8 @@ def test_vmapjvp(self, device, dtype, op):
783784
skip('pca_lowrank', ''),
784785
skip('svd_lowrank', ''),
785786
skip('nn.functional.feature_alpha_dropout', 'with_train'),
787+
788+
xfail('stft'), # transpose_ fallback
786789
}
787790

788791
@ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -854,6 +857,10 @@ def test_vmapjvpall(self, device, dtype, op):
854857
xfail('linalg.lu_factor_ex', ''),
855858
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
856859
xfail('special.log_ndtr', ''),
860+
xfail('fft.ihfft2'), # conj_physical fallback
861+
xfail('fft.ihfftn'), # conj_physical fallback
862+
xfail('istft'), # col2im fallback
863+
xfail('polar'), # complex fallback
857864
}))
858865
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
859866
def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1078,7 +1085,6 @@ def test_vjpvmap(self, device, dtype, op):
10781085
skip('nn.functional.fractional_max_pool2d'), # Random
10791086
skip('nn.functional.fractional_max_pool3d'), # Random
10801087
1081-
xfail('__rsub__', ''),
10821088
xfail('_masked.amax', ''),
10831089
xfail('_masked.amin', ''),
10841090
xfail('_masked.log_softmax', ''),
@@ -1093,25 +1099,6 @@ def test_vjpvmap(self, device, dtype, op):
10931099
xfail('cholesky', ''),
10941100
xfail('dist', ''),
10951101
xfail('eig', ''),
1096-
xfail('fft.fft', ''),
1097-
xfail('fft.fft2', ''),
1098-
xfail('fft.fftn', ''),
1099-
xfail('fft.hfft', ''),
1100-
xfail('fft.hfft2', ''),
1101-
xfail('fft.hfftn', ''),
1102-
xfail('fft.ifft', ''),
1103-
xfail('fft.ifft2', ''),
1104-
xfail('fft.ifftn', ''),
1105-
xfail('fft.ihfft', ''),
1106-
xfail('fft.ihfft2', ''),
1107-
xfail('fft.ihfftn', ''),
1108-
xfail('fft.irfft', ''),
1109-
xfail('fft.irfft2', ''),
1110-
xfail('fft.irfftn', ''),
1111-
xfail('fft.rfft', ''),
1112-
xfail('fft.rfft2', ''),
1113-
xfail('fft.rfftn', ''),
1114-
xfail('istft', ''),
11151102
xfail('linalg.det', ''),
11161103
xfail('linalg.eigh', ''),
11171104
xfail('linalg.eigvalsh', ''),
@@ -1160,14 +1147,12 @@ def test_vjpvmap(self, device, dtype, op):
11601147
xfail('norm', ''),
11611148
xfail('norm', 'fro'),
11621149
xfail('norm', 'inf'),
1163-
xfail('polar', ''),
11641150
xfail('put', ''),
11651151
xfail('renorm', ''),
11661152
xfail('softmax', ''),
11671153
xfail('softmax', 'with_dtype'),
11681154
xfail('solve', ''),
11691155
xfail('std_mean', ''),
1170-
xfail('stft', ''),
11711156
xfail('symeig', ''),
11721157
xfail('take', ''),
11731158
xfail('var_mean', ''),
@@ -1178,7 +1163,6 @@ def test_vjpvmap(self, device, dtype, op):
11781163
xfail('nn.functional.dropout2d', ''),
11791164
xfail('nn.functional.feature_alpha_dropout', 'without_train'),
11801165
xfail('svd_lowrank', ''),
1181-
xfail('rsub', ''),
11821166
xfail('linalg.lu_factor_ex', ''),
11831167
}))
11841168
def test_jvpvjp(self, device, dtype, op):

0 commit comments

Comments
 (0)