@@ -171,6 +171,27 @@ def wrapped(*args):
     return wrapped, tangents
 
+def get_jvp_variant_primals_tangents(f, sample):
+    # We want this higher-order variant of jvp, so that it can
+    # be used to wrap vmap
+    fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
+    tangents = _as_tuple(
+        tree_map(lambda x: torch.randn_like(x), primals))
+
+    @functools.wraps(f)
+    def wrapped(*args):
+        primals_in = args[:len(primals)]
+        tangents_in = args[len(primals):]
+        primals_out, tangents_out = jvp(fn, primals_in, tangents_in)
+
+        if isinstance(primals_out, torch.Tensor):
+            return (primals_out, tangents_out)
+        else:
+            flat_primals_out, _ = tree_flatten(primals_out)
+            flat_tangents_out, _ = tree_flatten(tangents_out)
+            return tuple(flat_primals_out + flat_tangents_out)
+
+    return wrapped, primals + tangents
 
 def is_inplace(op, variant):
     if hasattr(variant, "__wrapped__"):
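
Note on the calling convention above: get_jvp_variant_primals_tangents returns a wrapped function whose positional arguments are all of the primals followed by all of the tangents, so that vmap can batch over every input at once. The sketch below illustrates that pattern using only the public functorch jvp and vmap APIs; the function f, the tensor shapes, and the batch size are illustrative placeholders, not part of this diff.

    # A minimal sketch, assuming only the public functorch API.
    import torch
    from functorch import jvp, vmap

    def f(x, y):
        # Placeholder elementwise op standing in for an OpInfo sample.
        return x * y + y.sin()

    def jvp_of_f(x, y, tx, ty):
        # Primals first, tangents second -- the same flat signature the
        # wrapped helper exposes so vmap can treat all inputs uniformly.
        primal_out, tangent_out = jvp(f, (x, y), (tx, ty))
        return primal_out, tangent_out

    B = 3  # hypothetical batch size
    x, y = torch.randn(B, 5), torch.randn(B, 5)
    tx, ty = torch.randn(B, 5), torch.randn(B, 5)

    # Batch the forward-mode computation over dim 0 of every argument.
    primal_out, tangent_out = vmap(jvp_of_f)(x, y, tx, ty)

This vmap-over-jvp composition is what the new test_vmapjvpall in the next hunk exercises exhaustively across the OpInfo database.
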
@@ -596,6 +617,84 @@ def test_vmapjvp(self, device, dtype, op):
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, bdims=(0,)):
                 self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
 
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestOperators', 'test_vmapjvpall', {
+        skip('nn.functional.dropout'),  # randomness
+        skip('nn.functional.rrelu'),  # randomness
+
+        # Causing a CUDA assert, needs investigation
+        skip('div', 'floor_rounding', device_type='cuda'),
+        skip('div', 'no_rounding_mode', device_type='cuda'),
+        skip('div', 'trunc_rounding', device_type='cuda'),
+        skip('true_divide', device_type='cuda'),
+
+        # xfail list
+        xfail('linalg.inv'),
+        xfail('masked_fill'),
+        xfail('__rpow__'),
+        xfail('logit'),
+        xfail('linalg.tensorinv'),
+        xfail('nn.functional.pad', 'circular'),
+        xfail('linalg.matrix_power'),
+        xfail('cumprod'),
+        xfail('maximum'),
+        xfail('corrcoef'),
+        xfail('linalg.householder_product'),
+        xfail('tensor_split'),
+        xfail('nn.functional.gelu'),
+        xfail('quantile'),
+        xfail('var_mean'),
+        xfail('index_add'),
+        xfail('as_strided'),
+        xfail('linalg.eigvalsh'),
+        xfail('clamp', 'scalar'),
+        xfail('pow'),
+        xfail('fill_'),
+        xfail('linalg.cholesky'),
+        xfail('max', 'binary'),
+        xfail('nn.functional.gaussian_nll_loss'),
+        xfail('min', 'binary'),
+        xfail('index_fill'),
+        xfail('index_put'),
+        xfail('std_mean'),
+        xfail('double', 'channels_last'),
+        xfail('block_diag'),
+        xfail('float_power'),
+        xfail('diag_embed'),
+        xfail('fmin'),
+        xfail('minimum'),
+        xfail('scatter'),
+        xfail('fmax'),
+        xfail('matrix_exp'),
+        xfail('nanquantile'),
+        xfail('lu'),
+        xfail('nn.functional.linear'),
+        xfail('index_copy'),
+        xfail('masked_scatter'),
+        xfail('view_as_complex'),
+    })
+    # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
+    # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
+    # because that corresponds to "batched forward-mode AD" testing in PyTorch core
+    def test_vmapjvpall(self, device, dtype, op):
+        if is_inplace(op, op.get_op()):
+            # TODO: test in-place
+            self.skipTest("Skipped! NYI: inplace-testing not supported.")
+            return
+
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+
+        if not op.supports_forward_ad:
+            self.skipTest("Skipped! Forward AD not supported.")
+            return
+
+        for sample in samples:
+            arg_values = [sample.input] + list(sample.args)
+            kwarg_values = sample.kwargs
+            args = tuple([*arg_values, *kwarg_values])
+            fn, args = get_jvp_variant_primals_tangents(op, sample)
+            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}):
+                self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
 
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
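
For readers unfamiliar with the harness: get_fallback_and_vmap_exhaustive is an internal functorch test helper and its exact behaviour is not shown in this diff. Conceptually, the loop in test_vmapjvpall compares running the op one batch element at a time against a single vmap call, and asserts the results match. The simplified stand-in below, which assumes every argument is batched along dim 0, only illustrates that comparison; it is not the real helper.

    # A minimal, hypothetical stand-in for the loop-vs-vmap comparison;
    # the real get_fallback_and_vmap_exhaustive tries many in_dims layouts.
    import torch
    from functorch import vmap

    def loop_vs_vmap_check(fn, batched_args, atol=1e-4, rtol=1e-4):
        batch_size = batched_args[0].shape[0]
        # Reference path: run fn one sample at a time and stack the outputs.
        loop_outs = [fn(*[arg[i] for arg in batched_args]) for i in range(batch_size)]
        loop_out = tuple(torch.stack(cols) for cols in zip(*loop_outs))
        # Candidate path: a single vmap call over the same function.
        batched_out = vmap(fn)(*batched_args)
        for expected, actual in zip(loop_out, batched_out):
            torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)

Here fn and its flat argument tuple would come from get_jvp_variant_primals_tangents(op, sample), matching the assertEqual check in the test body above.
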
@@ -839,6 +938,7 @@ class TestDecompositionOpInfo(TestCase):
         skip('tensor_split'),
         skip('mvlgamma'),
         skip('tanh', device_type='cuda'),  # cuda bfloat16 failure
+        skip('nn.functional.tanhshrink', device_type='cuda'),  # cuda bfloat16 failure
         skip('eig'),
         skip('nn.functional.dropout'),
         skip('_masked.softmin'),