Skip to content

Commit 2bbedd1

Browse files
authored
Fix CI (#854)
- removed xfails - removed some decompositions b/c we have support in core
1 parent f945483 commit 2bbedd1

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

functorch/_src/eager_transforms.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,8 +1328,6 @@ def _register_python_decomposition_vmap(decomp):
13281328
_register_jit_decomposition(torch.ops.aten.trace.default)
13291329
_register_jit_decomposition(torch.ops.aten.nll_loss_backward.default)
13301330
_register_jit_decomposition(torch.ops.aten.nll_loss2d_backward.default)
1331-
_register_jit_decomposition(torch.ops.aten.mse_loss_backward.default)
1332-
_register_jit_decomposition(torch.ops.aten.l1_loss_backward.default)
13331331
_register_jit_decomposition(torch.ops.aten._log_softmax_backward_data.default)
13341332
_register_jit_decomposition(torch.ops.aten._softmax_backward_data.default)
13351333
_register_jit_decomposition(torch.ops.aten.log_sigmoid_forward.default)

functorch/csrc/DynamicLayer.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,6 @@ TORCH_LIBRARY_IMPL(_, FT_DYNAMIC_LAYER_BACK_MODE_KEY, m) {
492492
TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
493493
JVP_DECOMP(nll_loss_backward);
494494
JVP_DECOMP(nll_loss2d_backward);
495-
JVP_DECOMP(mse_loss_backward);
496-
JVP_DECOMP(l1_loss_backward);
497495
JVP_DECOMP(_log_softmax_backward_data);
498496
JVP_DECOMP(_softmax_backward_data);
499497
OP_DECOMPOSE(log_sigmoid);

test/test_ops.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
opsToleranceOverride,
2828
check_vmap_fallback,
2929
)
30+
import unittest
3031
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
3132
from functorch import grad, vjp, vmap, jacrev, jacfwd
3233
import torch.autograd.forward_ad as fwAD
@@ -693,6 +694,9 @@ def test_vmapvjp(self, device, dtype, op):
693694
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
694695
xfail('tensor_split'),
695696
697+
# https://github.com/pytorch/functorch/issues/859
698+
xfail('__getitem__'),
699+
696700
# Causing multiple forward mode AD issues, needs investigation
697701
xfail('nn.functional.batch_norm'),
698702
xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
@@ -712,7 +716,6 @@ def test_vmapvjp(self, device, dtype, op):
712716
xfail('nn.functional.max_unpool2d'),
713717
xfail('nn.functional.max_unpool3d'),
714718
715-
xfail('nn.functional.embedding'), # embedding_renorm_ does not support fwd AD
716719
xfail('put'), # calls put_ during vmap with only vmaps over other, not self
717720
})
718721
def test_vmapjvp(self, device, dtype, op):
@@ -758,6 +761,8 @@ def test_vmapjvp(self, device, dtype, op):
758761
xfail('nn.functional.batch_norm', device_type='cpu'),
759762
xfail('nn.functional.hinge_embedding_loss', device_type='cpu'),
760763

764+
# https://github.com/pytorch/functorch/issues/857
765+
skip('nn.functional.embedding', ''),
761766
xfail('nn.functional.soft_margin_loss', ''),
762767
xfail('nn.functional.binary_cross_entropy_with_logits', ''),
763768
xfail('linalg.householder_product'),
@@ -785,9 +790,11 @@ def test_vmapjvp(self, device, dtype, op):
785790
xfail('nn.functional.max_unpool2d'),
786791
xfail('nn.functional.max_unpool3d'),
787792

788-
xfail('nn.functional.embedding'), # embedding_renorm_ does not support fwd AD
789793
xfail('put'), # calls put_ during vmap with only vmaps over other, not self
790794
xfail('nn.functional.prelu'), # Call Tensor.as_strided
795+
796+
# https://github.com/pytorch/functorch/issues/859
797+
xfail('__getitem__'),
791798
}
792799

793800
@ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -872,6 +879,7 @@ def test_vmapjvpall(self, device, dtype, op):
872879
xfail('nn.functional.soft_margin_loss', ''),
873880
xfail('nn.functional.binary_cross_entropy_with_logits', ''),
874881
xfail('nn.functional.max_unpool1d', 'grad'),
882+
xfail('nn.functional.embedding', ''),
875883
xfail('lu_unpack'),
876884
xfail('nn.functional.glu'),
877885
xfail('nn.functional.bilinear'), # trilinear doesn't have batching rule
@@ -1149,20 +1157,17 @@ def get_vjp(cotangents, *primals):
11491157
xfail('nansum', ''),
11501158
xfail('nn.functional.batch_norm', ''),
11511159
xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
1152-
xfail('nn.functional.embedding', ''),
1160+
xfail('nn.functional.embedding'),
11531161
xfail('nn.functional.embedding', 'functorch'),
11541162
xfail('nn.functional.embedding_bag', ''),
11551163
xfail('nn.functional.grid_sample', ''),
11561164
xfail('nn.functional.hardsigmoid', ''),
1157-
xfail('nn.functional.hardswish', ''),
11581165
xfail('nn.functional.huber_loss', ''),
11591166
xfail('nn.functional.instance_norm', ''),
11601167
xfail('nn.functional.logsigmoid', ''),
11611168
xfail('nn.functional.pad', 'circular'),
1162-
xfail('nn.functional.prelu', ''),
11631169
xfail('nn.functional.softmin', ''),
11641170
xfail('nn.functional.softmin', 'with_dtype'),
1165-
xfail('nn.functional.softplus', ''),
11661171
xfail('renorm', ''),
11671172
xfail('std_mean', ''),
11681173
xfail('symeig', ''),
@@ -1181,7 +1186,6 @@ def get_vjp(cotangents, *primals):
11811186
xfail('nn.functional.pdist', ''),
11821187
xfail('scatter_reduce', 'sum'),
11831188
xfail('nn.functional.multi_margin_loss', ''),
1184-
xfail('nn.functional.smooth_l1_loss', ''),
11851189
xfail('scatter_reduce', 'mean'),
11861190
xfail('scatter_reduce', 'prod'),
11871191
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
@@ -1248,8 +1252,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
12481252
# in slower tests.
12491253
FUNCTORCH_HAS_FORMULA_BUT_NOT_PYTORCH = {
12501254
'nn.functional.nll_loss',
1251-
'nn.functional.l1_loss',
1252-
'nn.functional.mse_loss',
12531255
'softmax',
12541256
'log_softmax',
12551257
'nn.functional.cross_entropy',
@@ -1322,6 +1324,8 @@ def test_extremal_numerics_l1_loss(self, device):
13221324
cotangents = torch.randn_like(result, device=device)
13231325
self._compare_jacobians_of_vjp(torch.nn.functional.l1_loss, (cotangents, input, target))
13241326

1327+
# ("https://github.com/pytorch/functorch/issues/858")
1328+
@unittest.expectedFailure
13251329
def test_extremal_numerics_mse_loss(self, device):
13261330
N, C, H, W = 3, 4, 5, 6
13271331
shapes = ((N, C), (N, C, H), (N, C, H, W))

test/test_vmap.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3118,6 +3118,9 @@ class TestVmapOperatorsOpInfo(TestCase):
31183118
xfail('pca_lowrank', ''),
31193119
xfail('svd_lowrank', ''),
31203120

3121+
# https://github.com/pytorch/functorch/issues/859
3122+
xfail('__getitem__'),
3123+
31213124
# required rank 4 tensor to use channels_last format
31223125
xfail('bfloat16'),
31233126
xfail('bool'),

0 commit comments

Comments
 (0)