
Commit a53be5c

jvp transform fully composes with vmap (#340)
Test Plan:
- new tests
1 parent a0b272a · commit a53be5c

File tree: 3 files changed, +122 −4 lines changed

functorch/csrc/BatchRulesFactory.cpp

Lines changed: 21 additions & 4 deletions
@@ -40,10 +40,27 @@ std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rule(
     const Tensor& self, optional<int64_t> self_bdim,
     const Tensor& other, optional<int64_t> other_bdim,
     int64_t self_num_batch_dims) {
-  TORCH_CHECK(!other_bdim.has_value(),
-      "NYI: vmap over jvp of the primal. Please file an issue.");
-  auto self_ = moveBatchDimToFront(self, self_bdim);
-  auto result = at::_new_zeros_with_same_feature_meta(self, other, self_num_batch_dims + 1);
+  // The "self, other" naming is too confusing.
+  // What this function really says is "create a new tangent for this base".
+  const auto& base = other;
+  const auto& base_bdim = other_bdim;
+  const auto& tangent = self;
+  const auto& tangent_bdim = self_bdim;
+
+  // Three cases:
+  //          Case 1   Case 2   Case 3
+  // base     [6]      [B, 6]   [B, 6]
+  // tangent  [B, 5]   [5]      [B, 5]
+
+  // Cases 2 & 3: it doesn't matter at all what `tangent` is.
+  if (base_bdim) {
+    const auto result = at::_new_zeros_with_same_feature_meta(tangent, base, self_num_batch_dims);
+    return std::make_tuple(result, base_bdim);
+  }
+
+  // Case 1:
+  auto tangent_ = moveBatchDimToFront(tangent, tangent_bdim);
+  auto result = at::_new_zeros_with_same_feature_meta(tangent_, base, self_num_batch_dims + 1);
   return std::make_tuple(result, 0);
 }
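The TORCH_CHECK deleted above was the "NYI: vmap over jvp of the primal" guard, so the batched-base cases (Cases 2 and 3) are what this hunk newly supports. As a rough sketch of the composition it unblocks, the snippet below uses functorch's public vmap and jvp; the function f, the batch size of 3, and the [B, 6] / [6] shapes are illustrative assumptions chosen to mirror Case 2, not code from this commit.

import torch
from functorch import jvp, vmap

def f(x):
    return x.sin()

# Case 2 layout (illustrative): the primal ("base") carries the batch
# dimension, while the tangent does not.
batched_primal = torch.randn(3, 6)   # [B, 6]
tangent = torch.randn(6)             # [6]

# vmap over the primal only; every slice shares the same unbatched tangent.
primal_out, tangent_out = vmap(lambda x: jvp(f, (x,), (tangent,)))(batched_primal)
print(primal_out.shape, tangent_out.shape)  # both torch.Size([3, 6])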

test/test_ops.py

Lines changed: 100 additions & 0 deletions
@@ -171,6 +171,27 @@ def wrapped(*args):
 
     return wrapped, tangents
 
+def get_jvp_variant_primals_tangents(f, sample):
+    # We want this higher-order variant of jvp, so that it can
+    # be used to wrap vmap
+    fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
+    tangents = _as_tuple(
+        tree_map(lambda x: torch.randn_like(x), primals))
+
+    @functools.wraps(f)
+    def wrapped(*args):
+        primals_in = args[:len(primals)]
+        tangents_in = args[len(primals):]
+        primals_out, tangents_out = jvp(fn, primals_in, tangents_in)
+
+        if isinstance(primals_out, torch.Tensor):
+            return (primals_out, tangents_out)
+        else:
+            flat_primals_out, _ = tree_flatten(primals_out)
+            flat_tangents_out, _ = tree_flatten(tangents_out)
+            return tuple(flat_primals_out + flat_tangents_out)
+
+    return wrapped, primals + tangents
 
 def is_inplace(op, variant):
     if hasattr(variant, "__wrapped__"):
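To make the intent of the new helper concrete: the wrapped function exposes the primals and freshly sampled tangents as one flat tuple of positional tensor arguments, so vmap can batch over any subset of them. Below is a hypothetical usage sketch, not part of the diff; `op`, `sample`, `B`, and the expand-based batching are placeholders and assumptions, and the real exhaustive batching is done by get_fallback_and_vmap_exhaustive in test_vmapjvpall further down.

from functorch import vmap

# `op` stands for an OpInfo and `sample` for one of its sample inputs.
fn, primals_and_tangents = get_jvp_variant_primals_tangents(op, sample)

# Batch every primal and tangent along a new leading dimension of size B.
B = 3
batched_args = tuple(
    x.unsqueeze(0).expand(B, *x.shape).contiguous()
    for x in primals_and_tangents)

# Each slice of the vmapped output should match the unbatched call.
unbatched_out = fn(*primals_and_tangents)
batched_out = vmap(fn)(*batched_args)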
@@ -596,6 +617,84 @@ def test_vmapjvp(self, device, dtype, op):
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, bdims=(0,)):
                 self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
 
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @skipOps('TestOperators', 'test_vmapjvpall', {
+        skip('nn.functional.dropout'),  # randomness
+        skip('nn.functional.rrelu'),  # randomness
+
+        # Causing a CUDA assert, needs investigation
+        skip('div', 'floor_rounding', device_type='cuda'),
+        skip('div', 'no_rounding_mode', device_type='cuda'),
+        skip('div', 'trunc_rounding', device_type='cuda'),
+        skip('true_divide', device_type='cuda'),
+
+        # xfail list
+        xfail('linalg.inv'),
+        xfail('masked_fill'),
+        xfail('__rpow__'),
+        xfail('logit'),
+        xfail('linalg.tensorinv'),
+        xfail('nn.functional.pad', 'circular'),
+        xfail('linalg.matrix_power'),
+        xfail('cumprod'),
+        xfail('maximum'),
+        xfail('corrcoef'),
+        xfail('linalg.householder_product'),
+        xfail('tensor_split'),
+        xfail('nn.functional.gelu'),
+        xfail('quantile'),
+        xfail('var_mean'),
+        xfail('index_add'),
+        xfail('as_strided'),
+        xfail('linalg.eigvalsh'),
+        xfail('clamp', 'scalar'),
+        xfail('pow'),
+        xfail('fill_'),
+        xfail('linalg.cholesky'),
+        xfail('max', 'binary'),
+        xfail('nn.functional.gaussian_nll_loss'),
+        xfail('min', 'binary'),
+        xfail('index_fill'),
+        xfail('index_put'),
+        xfail('std_mean'),
+        xfail('double', 'channels_last'),
+        xfail('block_diag'),
+        xfail('float_power'),
+        xfail('diag_embed'),
+        xfail('fmin'),
+        xfail('minimum'),
+        xfail('scatter'),
+        xfail('fmax'),
+        xfail('matrix_exp'),
+        xfail('nanquantile'),
+        xfail('lu'),
+        xfail('nn.functional.linear'),
+        xfail('index_copy'),
+        xfail('masked_scatter'),
+        xfail('view_as_complex'),
+    })
+    # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
+    # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
+    # because that corresponds to "batched forward-mode AD" testing in PyTorch core
+    def test_vmapjvpall(self, device, dtype, op):
+        if is_inplace(op, op.get_op()):
+            # TODO: test in-place
+            self.skipTest("Skipped! NYI: inplace-testing not supported.")
+            return
+
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+
+        if not op.supports_forward_ad:
+            self.skipTest("Skipped! Forward AD not supported.")
+            return
+
+        for sample in samples:
+            arg_values = [sample.input] + list(sample.args)
+            kwarg_values = sample.kwargs
+            args = tuple([*arg_values, *kwarg_values])
+            fn, args = get_jvp_variant_primals_tangents(op, sample)
+            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}):
+                self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
 
     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
@@ -839,6 +938,7 @@ class TestDecompositionOpInfo(TestCase):
         skip('tensor_split'),
         skip('mvlgamma'),
         skip('tanh', device_type='cuda'),  # cuda bfloat16 failure
+        skip('nn.functional.tanhshrink', device_type='cuda'),  # cuda bfloat16 failure
         skip('eig'),
         skip('nn.functional.dropout'),
         skip('_masked.softmin'),

test/xfail_suggester.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def get_failed_test(line):
     'test_op_has_batch_rule_',
     'test_jvp_',
     'test_vmapjvp_',
+    'test_vmapjvpall_',
     'test_decomposition_',
     'test_make_fx_',
 }
