@@ -648,92 +648,6 @@ def test_vmapvjp(self, device, dtype, op):
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
                 self.assertEqual(loop_out, batched_out)

-    # There are several variations we care about
-    # 1) primal batched (TODO)
-    # 2) tangent batched (batched grads) <--
-    # 3) both batched (TODO)
-    # The below tests (2) only.
-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
-    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
-    @skipOps('TestOperators', 'test_vmapjvp', {
-        skip('nn.functional.dropout'),  # randomness
-        skip('nn.functional.rrelu'),  # randomness
-        skip('nn.functional.fractional_max_pool2d'),  # randomness
-        skip('nn.functional.fractional_max_pool3d'),  # randomness
-        skip('bernoulli', ''),  # randomness
-        skip('nn.functional.max_pool1d'),  # fails on cpu, runs on cuda
-
-        # TODO: fails in core due to in-place batched into non-batched,
-        # but fails here for a different reason
-        xfail('linalg.householder_product'),
-
-        # Tries to in-place a batched tensor into a non-batched tensor
-        xfail('matrix_exp'),
-
-        # Apparently these support forward AD, but we get "Trying to use forward AD..."
-        # These are cases where OpInfo has supports_forward_ad=True but disables
-        # the test
-        xfail('var_mean'),
-        xfail('std_mean'),
-
-        # RuntimeError: expand: the number of sizes provided (1) must be greater or
-        # equal to the number of dimensions in the tensor (2)
-        xfail('nanquantile'),
-        xfail('quantile'),
-
-        # Not implemented
-        xfail('scatter'),
-
-        # =============================================
-        # NB: The above failures also fail in PyTorch core.
-        # The failures below only fail in functorch
-        # =============================================
-
-        # Composite ops that do bad things. Need to be fixed in PyTorch core.
-        # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
-        xfail('tensor_split'),
-
-        # Causing multiple forward mode AD issues, needs investigation
-        xfail('nn.functional.batch_norm'),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
-
-        skip('nn.functional.feature_alpha_dropout', 'with_train'),
-        skip('pca_lowrank', ''),
-        skip('nn.functional.dropout2d', ''),
-        skip('nn.functional.feature_alpha_dropout', 'without_train'),
-        skip('svd_lowrank', ''),
-        xfail('nn.functional.soft_margin_loss', ''),
-        xfail('stft'),  # something weird is happening with shapes
-
-        xfail('double'),  # requires a rank-4 tensor to use channels_last format
-
-        # BUG: runs and produces numerical differences
-        skip('nn.functional.max_unpool1d', device_type='cpu'),  # fails everywhere except on mac
-        skip('nn.functional.max_unpool2d'),  # fails everywhere except on mac
-        skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
-
-        xfail('put'),  # calls put_ during vmap; only vmaps over other, not self
-    })
-    def test_vmapjvp(self, device, dtype, op):
-        if is_inplace(op, op.get_op()):
-            # TODO: test in-place
-            self.skipTest("Skipped! NYI: inplace-testing not supported.")
-            return
-
-        samples = op.sample_inputs(device, dtype, requires_grad=False)
-
-        if not op.supports_forward_ad:
-            self.skipTest("Skipped! Forward AD not supported.")
-            return
-
-        for sample in samples:
-            arg_values = [sample.input] + list(sample.args)
-            kwarg_values = sample.kwargs
-            args = tuple([*arg_values, *kwarg_values])
-            fn, args = get_jvp_variant(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
-                self.assertEqual(loop_out, batched_out)
-
     vmapjvpall_fail = {
         # The following are expected (not a bug)
         skip('bernoulli', ''),  # randomness
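
The deleted test_vmapjvp exercised variation (2) from the comment above: the primal stays fixed while vmap runs over a batch of tangents fed to jvp. A minimal sketch of that idea, using the functorch-style jvp/vmap API (the function and shapes below are illustrative, not taken from this commit):

    import torch
    from functorch import jvp, vmap

    def f(x):
        return x.sin().sum()

    x = torch.randn(3)            # single, non-batched primal
    tangents = torch.randn(5, 3)  # a batch of 5 tangent vectors

    def jvp_of_f(t):
        # directional derivative of f at x along a single tangent t
        return jvp(f, (x,), (t,))[1]

    # vmap over the tangent argument: one jvp per row of `tangents`
    batched_jvps = vmap(jvp_of_f)(tangents)  # shape (5,)

The exhaustive helper in these tests compares exactly this kind of vmapped run against an explicit per-sample loop, which is what the loop_out/batched_out assertions check.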
@@ -757,7 +671,8 @@ def test_vmapjvp(self, device, dtype, op):

         # Not actually a problem: embedding with max_norm mutates the weight
         # and causes different runs to produce different results.
-        xfail('nn.functional.embedding', ''),
+        # skip because this is flaky depending on what the max_norm is!
+        skip('nn.functional.embedding', ''),
         xfail('nn.functional.soft_margin_loss', ''),
         xfail('nn.functional.binary_cross_entropy_with_logits', ''),
         xfail('linalg.householder_product'),
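
The embedding entry above is switched from xfail to skip because the max_norm path really does mutate state: F.embedding with max_norm renormalizes the selected weight rows in place, so back-to-back runs can see different weights. A small sketch of that behavior (the shapes and values here are illustrative only):

    import torch
    import torch.nn.functional as F

    weight = torch.ones(10, 3) * 5.0         # every row has norm well above 1
    idx = torch.tensor([0, 1])

    before = weight.clone()
    F.embedding(idx, weight, max_norm=1.0)    # renorms rows 0 and 1 in place
    print(torch.equal(before, weight))        # False: the weight was mutated

Whether a rerun matches the loop reference therefore depends on the max_norm value a given sample uses, which is why the test is flaky rather than deterministically failing.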
@@ -788,7 +703,7 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('nn.functional.prelu'),  # Calls Tensor.as_strided
     }

-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
         tol1('nn.functional.conv_transpose3d',
              {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
@@ -818,7 +733,7 @@ def test_vmapjvpall(self, device, dtype, op):
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
                 self.assertEqual(loop_out, batched_out)

-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
         xfail('linalg.solve_triangular'),
         xfail('nn.functional.huber_loss'),