     opsToleranceOverride,
     check_vmap_fallback,
 )
+import unittest
 from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
 from functorch import grad, vjp, vmap, jacrev, jacfwd
 import torch.autograd.forward_ad as fwAD
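For readers skimming the imports: torch.autograd.forward_ad, aliased to fwAD above, is the forward-mode AD entry point these tests exercise. The snippet below is a minimal, standalone usage sketch of that public API, independent of any particular test in this file.

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
tangent = torch.ones_like(x)  # direction for the Jacobian-vector product

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, tangent)   # pair the primal with its tangent
    dual_out = torch.sin(dual_x)          # ordinary ops propagate tangents
    primal_out, jvp_out = fwAD.unpack_dual(dual_out)

# For y = sin(x), the JVP along `tangent` is cos(x) * tangent.
assert torch.allclose(jvp_out, torch.cos(x) * tangent)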
@@ -693,6 +694,9 @@ def test_vmapvjp(self, device, dtype, op):
         # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
         xfail('tensor_split'),
 
+        # https://github.com/pytorch/functorch/issues/859
+        xfail('__getitem__'),
+
         # Causing multiple forward mode AD issues, needs investigation
         xfail('nn.functional.batch_norm'),
         xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
@@ -712,7 +716,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('nn.functional.max_unpool2d'),
         xfail('nn.functional.max_unpool3d'),
 
-        xfail('nn.functional.embedding'),  # embedding_renorm_ does not support fwd AD
         xfail('put'),  # calls put_ during vmap with only vmaps over other, not self
     })
     def test_vmapjvp(self, device, dtype, op):
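The xfail(...)/skip(...) entries being shuffled in these hunks are per-op expected-failure markers that the suite applies before running an @ops-parametrized test. The sketch below illustrates the general idea with stand-in helpers; ExpectedFailure, apply_xfails, and the sample failure set are hypothetical and are not functorch's actual xfail/skip/skipOps implementations.

import unittest
from collections import namedtuple

# Hypothetical stand-in for the suite's xfail markers.
ExpectedFailure = namedtuple('ExpectedFailure', ['op_name', 'variant_name'])

def xfail(op_name, variant_name=''):
    return ExpectedFailure(op_name, variant_name)

# Illustrative failure set in the spirit of the ones edited above.
VMAPJVP_FAIL = {
    xfail('__getitem__'),
    xfail('nn.functional.batch_norm'),
}

def apply_xfails(fail_set, op_name, variant_name=''):
    """If (op_name, variant_name) is in fail_set, mark the test as an expected failure."""
    def decorator(test_fn):
        if ExpectedFailure(op_name, variant_name) in fail_set:
            return unittest.expectedFailure(test_fn)
        return test_fn
    return decorator

class ExampleTests(unittest.TestCase):
    # '__getitem__' is in the failure set, so this known failure is reported as expected.
    @apply_xfails(VMAPJVP_FAIL, '__getitem__')
    def test_getitem(self):
        raise RuntimeError("known vmap-over-__getitem__ issue (functorch #859)")

    # 'tensor_split' is not in the set, so this test runs normally.
    @apply_xfails(VMAPJVP_FAIL, 'tensor_split')
    def test_tensor_split(self):
        self.assertEqual(1 + 1, 2)

if __name__ == '__main__':
    unittest.main()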
@@ -758,6 +761,8 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('nn.functional.batch_norm', device_type='cpu'),
         xfail('nn.functional.hinge_embedding_loss', device_type='cpu'),
 
+        # https://github.com/pytorch/functorch/issues/857
+        skip('nn.functional.embedding', ''),
         xfail('nn.functional.soft_margin_loss', ''),
         xfail('nn.functional.binary_cross_entropy_with_logits', ''),
         xfail('linalg.householder_product'),
@@ -785,9 +790,11 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('nn.functional.max_unpool2d'),
         xfail('nn.functional.max_unpool3d'),
 
-        xfail('nn.functional.embedding'),  # embedding_renorm_ does not support fwd AD
         xfail('put'),  # calls put_ during vmap with only vmaps over other, not self
         xfail('nn.functional.prelu'),  # Call Tensor.as_strided
+
+        # https://github.com/pytorch/functorch/issues/859
+        xfail('__getitem__'),
     }
 
     @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -872,6 +879,7 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nn.functional.soft_margin_loss', ''),
         xfail('nn.functional.binary_cross_entropy_with_logits', ''),
         xfail('nn.functional.max_unpool1d', 'grad'),
+        xfail('nn.functional.embedding', ''),
         xfail('lu_unpack'),
         xfail('nn.functional.glu'),
         xfail('nn.functional.bilinear'),  # trilinear doesn't have batching rule
@@ -1149,20 +1157,17 @@ def get_vjp(cotangents, *primals):
         xfail('nansum', ''),
         xfail('nn.functional.batch_norm', ''),
         xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
-        xfail('nn.functional.embedding', ''),
+        xfail('nn.functional.embedding'),
         xfail('nn.functional.embedding', 'functorch'),
         xfail('nn.functional.embedding_bag', ''),
         xfail('nn.functional.grid_sample', ''),
         xfail('nn.functional.hardsigmoid', ''),
-        xfail('nn.functional.hardswish', ''),
         xfail('nn.functional.huber_loss', ''),
         xfail('nn.functional.instance_norm', ''),
         xfail('nn.functional.logsigmoid', ''),
         xfail('nn.functional.pad', 'circular'),
-        xfail('nn.functional.prelu', ''),
         xfail('nn.functional.softmin', ''),
         xfail('nn.functional.softmin', 'with_dtype'),
-        xfail('nn.functional.softplus', ''),
         xfail('renorm', ''),
         xfail('std_mean', ''),
         xfail('symeig', ''),
@@ -1181,7 +1186,6 @@ def get_vjp(cotangents, *primals):
         xfail('nn.functional.pdist', ''),
         xfail('scatter_reduce', 'sum'),
         xfail('nn.functional.multi_margin_loss', ''),
-        xfail('nn.functional.smooth_l1_loss', ''),
         xfail('scatter_reduce', 'mean'),
         xfail('scatter_reduce', 'prod'),
         skip('linalg.householder_product', '', device_type='cuda'),  # flaky, I'm not sure why
@@ -1248,8 +1252,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
     # in slower tests.
     FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
         'nn.functional.nll_loss',
-        'nn.functional.l1_loss',
-        'nn.functional.mse_loss',
         'softmax',
         'log_softmax',
         'nn.functional.cross_entropy',
@@ -1322,6 +1324,8 @@ def test_extremal_numerics_l1_loss(self, device):
         cotangents = torch.randn_like(result, device=device)
         self._compare_jacobians_of_vjp(torch.nn.functional.l1_loss, (cotangents, input, target))
 
+    # ("https://github.com/pytorch/functorch/issues/858")
+    @unittest.expectedFailure
     def test_extremal_numerics_mse_loss(self, device):
         N, C, H, W = 3, 4, 5, 6
         shapes = ((N, C), (N, C, H), (N, C, H, W))
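The new `import unittest` in the first hunk exists so that test_extremal_numerics_mse_loss can be decorated with @unittest.expectedFailure while functorch issue #858 is open. As a standalone illustration of that decorator's behavior (independent of functorch): a decorated test that fails is recorded as an expected failure, while one that passes is flagged as an unexpected success.

import unittest

class ExpectedFailureDemo(unittest.TestCase):
    @unittest.expectedFailure
    def test_known_bug(self):
        # Fails, but is recorded as an "expected failure" rather than an error.
        self.assertEqual(0.1 + 0.2, 0.3)

    @unittest.expectedFailure
    def test_fixed_bug(self):
        # Passes, so it is recorded as an "unexpected success", a reminder
        # to drop the decorator once the underlying issue is resolved.
        self.assertAlmostEqual(0.1 + 0.2, 0.3)

if __name__ == '__main__':
    unittest.main(verbosity=2)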