@@ -591,7 +591,6 @@ def vjp_of_vjp(*args_and_cotangents):
         skip('linalg.svdvals'),  # really annoying thing where it passes correctness check but not has_batch_rule
         xfail('__getitem__', ''),
         xfail('_masked.prod'),  # calls aten::item
-        xfail('block_diag'),
         xfail('eig'),  # calls aten::item
         xfail('linalg.det', ''),  # calls .item()
         xfail('linalg.eig'),  # Uses aten::allclose
@@ -664,7 +663,6 @@ def test_vmapvjp(self, device, dtype, op):
 
         # Try to in-place batched tensor into non-batched tensor
         xfail('matrix_exp'),
-        xfail('block_diag'),  # TODO: We expect this to fail in core, but it doesn't
 
         # Apprently these support forward AD, but we get "Trying to use forward AD..."
         # These are cases where OpInfo has supports_forward_ad=True, but disables
@@ -763,7 +761,6 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('as_strided'),
         xfail('nn.functional.gaussian_nll_loss'),
         xfail('std_mean'),
-        xfail('block_diag'),
         xfail('scatter'),
         xfail('matrix_exp'),
         xfail('nanquantile'),
@@ -950,7 +947,6 @@ def test():
         xfail('to_sparse'),
         xfail('unfold'),
         xfail('vdot'),
-        xfail('block_diag'),
         xfail('nn.functional.dropout'),
         xfail('_masked.prod'),
         xfail('fft.ihfft2'),
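
For context (not part of the diff itself), a minimal sketch of the composition these tests exercise, assuming a functorch build where torch.block_diag batching works, which is what dropping the xfail('block_diag') entries implies:

# Minimal sketch; names below are standard functorch/PyTorch APIs, the
# shapes and helper function are illustrative assumptions.
import torch
from functorch import vmap, vjp

def f(a, b):
    # block_diag places its arguments along the diagonal of a larger matrix.
    return torch.block_diag(a, b)

# A batch of 3 (2x2, 4x4) matrix pairs; vmap maps f over the leading dim.
a = torch.randn(3, 2, 2)
b = torch.randn(3, 4, 4)

out = vmap(f)(a, b)  # previously hit the missing batch rule
assert out.shape == (3, 6, 6)

def single_vjp(a, b):
    # vjp of f at (a, b), pulled back through an all-ones cotangent.
    result, vjp_fn = vjp(f, a, b)
    return vjp_fn(torch.ones_like(result))

grad_a, grad_b = vmap(single_vjp)(a, b)  # the test_vmapvjp-style composition
assert grad_a.shape == a.shape and grad_b.shape == b.shape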