This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 88eae98

Cleanup jvp testing (#890)
- deleted test_vmapjvp (test_vmapjvpall is a superset of test_vmapjvp)
- added additional_op_db to things tested by test_vmapjvpall
- did some accounting in discover_coverage
1 parent f236d02 commit 88eae98
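
As background for the first bullet: the deleted test only exercised the "tangent batched" case, i.e. vmap over the tangents of a jvp. A minimal sketch of that pattern, using functorch's public jvp/vmap rather than the test harness (the function f and the tensor names are made up for illustration):

    import torch
    from functorch import jvp, vmap

    def f(x):
        # stand-in for an OpInfo operation that supports forward AD
        return x.sin().sum()

    x = torch.randn(3)       # single (unbatched) primal
    xt = torch.randn(5, 3)   # a batch of 5 tangents

    # "tangent batched": vmap only over the tangent argument
    batched = vmap(lambda t: jvp(f, (x,), (t,))[1])(xt)

    # reference: run the jvp once per tangent in a plain Python loop
    expected = torch.stack([jvp(f, (x,), (t,))[1] for t in xt])
    assert torch.allclose(batched, expected)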

3 files changed (+12, -93 lines)

test/discover_coverage.py

Lines changed: 6 additions & 4 deletions
@@ -628,8 +628,6 @@ class Support(enum.Enum):
     'nn.functional.rrelu',  # not actually problem, randomness testing artifact
     'normal',  # not actually problem, randomness testing artifact
     'bernoulli',  # not actually problem, randomness testing artifact
-    'nn.functional.embedding',  # max_norm causes testing to be weird
-    # 'multinomial',
 }


@@ -746,8 +744,12 @@ def _supports_vmapjvp_base(self, test):
             'nn.functional.batch_norm',  # testing problem
             'normal',  # not actually problem, randomness testing artifact
             'bernoulli',  # not actually problem, randomness testing artifact
-            'dropout2d',  # not actually problem, randomness testing artifact
-            'dropout',  # not actually problem, randomness testing artifact
+            'nn.functional.dropout2d',  # not actually problem, randomness testing artifact
+            'nn.functional.dropout',  # not actually problem, randomness testing artifact
+            # Not a problem.
+            # It's just that the max_norm testing mutates inputs...
+            # (we have our own functorch variant of the OpInfo without max_norm)
+            'nn.functional.embedding',
         }
         if self.name in VMAPJVP_EXEMPTIONS:
             return Support.YES
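
The embedding/max_norm situation referenced above is easy to see outside the test suite: F.embedding with max_norm renormalizes the selected weight rows in place, so repeated runs over the same sample inputs see different weights. A small repro (illustrative only, not code from this commit):

    import torch
    import torch.nn.functional as F

    weight = 5 * torch.randn(10, 3)   # rows whose norms (almost surely) exceed 1
    idx = torch.tensor([0, 1, 2])
    before = weight.clone()

    # max_norm triggers an in-place renorm of the rows that idx touches
    F.embedding(idx, weight, max_norm=1.0)

    print(torch.equal(weight, before))          # False: rows 0-2 were rescaled in place
    print(torch.equal(weight[3:], before[3:]))  # True: rows not looked up are untouched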

test/functorch_additional_op_db.py

Lines changed: 2 additions & 0 deletions
@@ -181,6 +181,8 @@ def generator():
            op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs),
            dtypes=floating_types_and(torch.bfloat16, torch.float16),
            sample_inputs_func=sample_inputs_embedding,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
            supports_out=False,
            ))
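
The two new flags opt this functorch embedding OpInfo (the max_norm-free variant mentioned in the discover_coverage comment) into the forward-AD and forward-over-reverse tests. A minimal sketch of the forward-AD path being claimed, differentiating with respect to the weight (illustrative names; assumes a PyTorch build where embedding has a forward-AD formula, which is exactly what supports_forward_ad asserts):

    import torch
    import torch.nn.functional as F
    from functorch import jvp

    weight = torch.randn(10, 3)
    w_tangent = torch.randn(10, 3)
    idx = torch.tensor([1, 4, 7])

    # embedding is a gather on the weight, so its JVP is the same gather
    # applied to the weight's tangent
    out, out_tangent = jvp(lambda w: F.embedding(idx, w), (weight,), (w_tangent,))
    assert torch.allclose(out, weight[idx])
    assert torch.allclose(out_tangent, w_tangent[idx])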

test/test_ops.py

Lines changed: 4 additions & 89 deletions
@@ -648,92 +648,6 @@ def test_vmapvjp(self, device, dtype, op):
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
                 self.assertEqual(loop_out, batched_out)

-    # There are several variations we care about
-    # 1) primal batched (TODO)
-    # 2) tangent batched (batched grads) <--
-    # 3) both batched (TODO)
-    # The below tests (2) only.
-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
-    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
-    @skipOps('TestOperators', 'test_vmapjvp', {
-        skip('nn.functional.dropout'),  # randomness
-        skip('nn.functional.rrelu'),  # randomness
-        skip('nn.functional.fractional_max_pool2d'),  # randomness
-        skip('nn.functional.fractional_max_pool3d'),  # randomness
-        skip('bernoulli', ''),  # randomness
-        skip('nn.functional.max_pool1d'),  # fails on cpu, runs on cuda
-
-        # TODO: fails in core due to in-place batched nto non-batched
-        # but fails here for a different reason
-        xfail('linalg.householder_product'),
-
-        # Try to in-place batched tensor into non-batched tensor
-        xfail('matrix_exp'),
-
-        # Apprently these support forward AD, but we get "Trying to use forward AD..."
-        # These are cases where OpInfo has supports_forward_ad=True, but disables
-        # the test
-        xfail('var_mean'),
-        xfail('std_mean'),
-
-        # RuntimeError: expand: the number of sizes provided (1) must be greater or
-        # equal to the number of dimensions in the tensor (2)
-        xfail('nanquantile'),
-        xfail('quantile'),
-
-        # Not implemented
-        xfail('scatter'),
-
-        # =============================================
-        # NB: The above failures also fail in PyTorch core.
-        # The failures below only fail in functorch
-        # =============================================
-
-        # Composite ops that do bad things. Need to be fixed in PyTorch core.
-        # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
-        xfail('tensor_split'),
-
-        # Causing multiple forward mode AD issues, needs investigation
-        xfail('nn.functional.batch_norm'),
-        xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
-
-        skip('nn.functional.feature_alpha_dropout', 'with_train'),
-        skip('pca_lowrank', ''),
-        skip('nn.functional.dropout2d', ''),
-        skip('nn.functional.feature_alpha_dropout', 'without_train'),
-        skip('svd_lowrank', ''),
-        xfail('nn.functional.soft_margin_loss', ''),
-        xfail('stft'),  # something weird is happening with shapes
-
-        xfail('double'),  # required rank 4 tensor to use channels_last format
-
-        # BUG: runs and produces numerical differences
-        skip('nn.functional.max_unpool1d', device_type='cpu'),  # fails everywhere except on mac
-        skip('nn.functional.max_unpool2d'),  # fails everywhere except on mac
-        skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
-
-        xfail('put'),  # calls put_ during vmap with only vmaps over other, not self
-    })
-    def test_vmapjvp(self, device, dtype, op):
-        if is_inplace(op, op.get_op()):
-            # TODO: test in-place
-            self.skipTest("Skipped! NYI: inplace-testing not supported.")
-            return
-
-        samples = op.sample_inputs(device, dtype, requires_grad=False)
-
-        if not op.supports_forward_ad:
-            self.skipTest("Skipped! Forward AD not supported.")
-            return
-
-        for sample in samples:
-            arg_values = [sample.input] + list(sample.args)
-            kwarg_values = sample.kwargs
-            args = tuple([*arg_values, *kwarg_values])
-            fn, args = get_jvp_variant(op, sample)
-            for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op, bdims=(0,)):
-                self.assertEqual(loop_out, batched_out)
-
     vmapjvpall_fail = {
         # The following are expected (not a bug)
         skip('bernoulli', ''),  # randomness

@@ -757,7 +671,8 @@ def test_vmapjvp(self, device, dtype, op):

         # Not actually a problem: embedding with max_norm mutates the weight
         # and causes different runs to produce different results.
-        xfail('nn.functional.embedding', ''),
+        # skip because this is flaky depending on what the max_norm is!
+        skip('nn.functional.embedding', ''),
         xfail('nn.functional.soft_margin_loss', ''),
         xfail('nn.functional.binary_cross_entropy_with_logits', ''),
         xfail('linalg.householder_product'),

@@ -788,7 +703,7 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('nn.functional.prelu'),  # Call Tensor.as_strided
     }

-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
         tol1('nn.functional.conv_transpose3d',
              {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),

@@ -818,7 +733,7 @@ def test_vmapjvpall(self, device, dtype, op):
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}, opinfo=op):
                 self.assertEqual(loop_out, batched_out)

-    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
         xfail('linalg.solve_triangular'),
         xfail('nn.functional.huber_loss'),
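
For reference, test_vmapjvpall is a superset of the deleted test because its exhaustive helper vmaps over every argument of the jvp'ed function, primals included. A rough sketch of the extra "both batched" case it covers, again with made-up names and the public functorch API rather than get_fallback_and_vmap_exhaustive:

    import torch
    from functorch import jvp, vmap

    def f(x):
        return x.sin()

    xs = torch.randn(5, 3)   # batch of primals
    ts = torch.randn(5, 3)   # batch of matching tangents

    # "both batched": vmap over primal and tangent together
    both = vmap(lambda x, t: jvp(f, (x,), (t,))[1])(xs, ts)

    # reference: one jvp per (primal, tangent) pair
    expected = torch.stack([jvp(f, (x,), (t,))[1] for x, t in zip(xs, ts)])
    assert torch.allclose(both, expected)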
