@@ -710,6 +710,7 @@ def test_vmapvjp(self, device, dtype, op):
710
710
skip ('nn.functional.dropout2d' , '' ),
711
711
skip ('nn.functional.feature_alpha_dropout' , 'without_train' ),
712
712
skip ('svd_lowrank' , '' ),
713
+ xfail ('stft' ), # something weird is happening with shapes
713
714
})
714
715
def test_vmapjvp (self , device , dtype , op ):
715
716
if is_inplace (op , op .get_op ()):
@@ -783,6 +784,8 @@ def test_vmapjvp(self, device, dtype, op):
783
784
skip ('pca_lowrank' , '' ),
784
785
skip ('svd_lowrank' , '' ),
785
786
skip ('nn.functional.feature_alpha_dropout' , 'with_train' ),
787
+
788
+ xfail ('stft' ), # transpose_ fallback
786
789
}
787
790
788
791
@ops (functorch_lagging_op_db , allowed_dtypes = (torch .float ,))
@@ -854,6 +857,10 @@ def test_vmapjvpall(self, device, dtype, op):
854
857
xfail ('linalg.lu_factor_ex' , '' ),
855
858
xfail ('nn.functional.feature_alpha_dropout' , 'with_train' ),
856
859
xfail ('special.log_ndtr' , '' ),
860
+ xfail ('fft.ihfft2' ), # conj_physical fallback
861
+ xfail ('fft.ihfftn' ), # conj_physical fallback
862
+ xfail ('istft' ), # col2im fallback
863
+ xfail ('polar' ), # complex fallback
857
864
}))
858
865
@toleranceOverride ({torch .float32 : tol (atol = 1e-04 , rtol = 1e-04 )})
859
866
def test_vmapjvpall_has_batch_rule (self , device , dtype , op ):
@@ -1078,7 +1085,6 @@ def test_vjpvmap(self, device, dtype, op):
1078
1085
skip ('nn.functional.fractional_max_pool2d' ), # Random
1079
1086
skip ('nn.functional.fractional_max_pool3d' ), # Random
1080
1087
1081
- xfail ('__rsub__' , '' ),
1082
1088
xfail ('_masked.amax' , '' ),
1083
1089
xfail ('_masked.amin' , '' ),
1084
1090
xfail ('_masked.log_softmax' , '' ),
@@ -1093,25 +1099,6 @@ def test_vjpvmap(self, device, dtype, op):
1093
1099
xfail ('cholesky' , '' ),
1094
1100
xfail ('dist' , '' ),
1095
1101
xfail ('eig' , '' ),
1096
- xfail ('fft.fft' , '' ),
1097
- xfail ('fft.fft2' , '' ),
1098
- xfail ('fft.fftn' , '' ),
1099
- xfail ('fft.hfft' , '' ),
1100
- xfail ('fft.hfft2' , '' ),
1101
- xfail ('fft.hfftn' , '' ),
1102
- xfail ('fft.ifft' , '' ),
1103
- xfail ('fft.ifft2' , '' ),
1104
- xfail ('fft.ifftn' , '' ),
1105
- xfail ('fft.ihfft' , '' ),
1106
- xfail ('fft.ihfft2' , '' ),
1107
- xfail ('fft.ihfftn' , '' ),
1108
- xfail ('fft.irfft' , '' ),
1109
- xfail ('fft.irfft2' , '' ),
1110
- xfail ('fft.irfftn' , '' ),
1111
- xfail ('fft.rfft' , '' ),
1112
- xfail ('fft.rfft2' , '' ),
1113
- xfail ('fft.rfftn' , '' ),
1114
- xfail ('istft' , '' ),
1115
1102
xfail ('linalg.det' , '' ),
1116
1103
xfail ('linalg.eigh' , '' ),
1117
1104
xfail ('linalg.eigvalsh' , '' ),
@@ -1160,14 +1147,12 @@ def test_vjpvmap(self, device, dtype, op):
1160
1147
xfail ('norm' , '' ),
1161
1148
xfail ('norm' , 'fro' ),
1162
1149
xfail ('norm' , 'inf' ),
1163
- xfail ('polar' , '' ),
1164
1150
xfail ('put' , '' ),
1165
1151
xfail ('renorm' , '' ),
1166
1152
xfail ('softmax' , '' ),
1167
1153
xfail ('softmax' , 'with_dtype' ),
1168
1154
xfail ('solve' , '' ),
1169
1155
xfail ('std_mean' , '' ),
1170
- xfail ('stft' , '' ),
1171
1156
xfail ('symeig' , '' ),
1172
1157
xfail ('take' , '' ),
1173
1158
xfail ('var_mean' , '' ),
@@ -1178,7 +1163,6 @@ def test_vjpvmap(self, device, dtype, op):
1178
1163
xfail ('nn.functional.dropout2d' , '' ),
1179
1164
xfail ('nn.functional.feature_alpha_dropout' , 'without_train' ),
1180
1165
xfail ('svd_lowrank' , '' ),
1181
- xfail ('rsub' , '' ),
1182
1166
xfail ('linalg.lu_factor_ex' , '' ),
1183
1167
}))
1184
1168
def test_jvpvjp (self , device , dtype , op ):
0 commit comments