Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit c5ce4d0

Browse files
authored
jvp x vjp testing (#343)
Failures fall into three categories: lack of PyTorch forward-mode AD support (the majority); errors involving efficient zero tensors; and CUDA asserts (which really need to be investigated). All of the problems should be reproducible on the pytorch/pytorch side.
1 parent 26c1847 commit c5ce4d0

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

test/test_ops.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,120 @@ def test_vjpvmap(self, device, dtype, op):
10191019

10201020
self.assertEqual(result_vjps, expected_vjps)
10211021

1022+
@ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
    # These are weirdly non-deterministic
    skip('nn.functional.conv2d', '', device_type='cpu'),
    skip('nn.functional.conv2d', 'no_bias', device_type='cpu'),
    skip('nn.functional.conv2d', 'stride_no_bias', device_type='cpu'),
    skip('nn.functional.conv2d', 'stride_padding_no_bias', device_type='cpu'),
    skip('nn.functional.fractional_max_pool2d'),  # Random
    skip('nn.functional.fractional_max_pool3d'),  # Random

    # NOTE(review): per the commit message, these xfails are mostly missing
    # PyTorch forward-mode AD support, plus efficient-zero-tensor errors and
    # some unexplained CUDA asserts — confirm per-op before removing any.
    xfail('_masked.log_softmax'),
    xfail('_masked.softmax'),
    xfail('_masked.softmin'),
    xfail('block_diag'),
    xfail('cdist'),
    xfail('fft.fft'),
    xfail('fft.fft2'),
    xfail('fft.fftn'),
    xfail('fft.hfft'),
    xfail('fft.hfft2'),
    xfail('fft.hfftn'),
    xfail('fft.ifft'),
    xfail('fft.ifft2'),
    xfail('fft.ifftn'),
    xfail('fft.ihfft'),
    xfail('fft.ihfft2'),
    xfail('fft.ihfftn'),
    xfail('fft.irfft'),
    xfail('fft.irfft2'),
    xfail('fft.irfftn'),
    xfail('fft.rfft'),
    xfail('fft.rfft2'),
    xfail('fft.rfftn'),
    xfail('istft'),
    xfail('log_softmax'),
    xfail('log_softmax', 'dtype'),
    xfail('logcumsumexp'),
    xfail('nn.functional.batch_norm'),
    xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
    xfail('nn.functional.bilinear'),
    xfail('nn.functional.binary_cross_entropy'),
    xfail('nn.functional.binary_cross_entropy_with_logits', device_type='cuda'),
    xfail('nn.functional.celu'),
    xfail('nn.functional.cross_entropy'),
    xfail('nn.functional.cross_entropy', 'mean'),
    xfail('nn.functional.cross_entropy', 'none'),
    xfail('nn.functional.cross_entropy', 'sum'),
    xfail('nn.functional.elu'),
    xfail('nn.functional.embedding'),
    xfail('nn.functional.embedding', 'functorch'),
    xfail('nn.functional.embedding_bag'),
    xfail('nn.functional.glu'),
    xfail('nn.functional.grid_sample'),
    xfail('nn.functional.hardsigmoid'),
    xfail('nn.functional.hardswish'),
    xfail('nn.functional.huber_loss'),
    xfail('nn.functional.instance_norm'),
    xfail('nn.functional.layer_norm'),
    xfail('nn.functional.leaky_relu'),
    xfail('nn.functional.logsigmoid'),
    xfail('nn.functional.mse_loss'),
    xfail('nn.functional.nll_loss'),
    xfail('nn.functional.pad', 'circular'),
    xfail('nn.functional.prelu'),
    xfail('nn.functional.selu'),
    xfail('nn.functional.softmin'),
    xfail('nn.functional.softmin', 'with_dtype'),
    xfail('nn.functional.softplus'),
    xfail('put'),
    xfail('softmax'),
    xfail('softmax', 'with_dtype'),
    xfail('stft'),
    xfail('take'),
}))
def test_jvpvjp(self, device, dtype, op):
    """Check functorch's jvp composed over vjp against a reference.

    For each sample input of ``op``, compute ``jvp(vjp_fn, cotangents,
    tangents)`` via functorch and compare it against a reference built
    from ``ref_vjp`` with PyTorch's built-in forward-mode AD
    (``torch.autograd.forward_ad``) pushed through the vjp function.
    """
    # Ops without autograd support cannot have a vjp at all.
    if not op.supports_autograd:
        self.skipTest("Skipped! Autograd not supported.")
        return

    samples = op.sample_inputs(device, dtype, requires_grad=True)

    # TODO: test in-place
    if is_inplace(op, op.get_op()):
        self.skipTest("Skipped! NYI: inplace-testing not supported.")
        return

    for sample in samples:
        fn, primals = normalize_op_input_output(op, sample)
        result = fn(*primals)
        # Random cotangents/tangents matching the structure of fn's output.
        cotangents = tree_map(lambda x: torch.randn_like(x), result)
        tangents = tree_map(lambda x: torch.randn_like(x), result)

        # System under test: forward-mode through the vjp function.
        _, vjp_fn = vjp(fn, *primals)
        result = jvp(vjp_fn, (cotangents,), (tangents,))
        # jvp returns a (primal_out, tangent_out) pair.
        self.assertEqual(len(result), 2)

        def reference(primals, cotangents, tangents):
            # Reference: run the (reference) vjp on dual tensors so that
            # PyTorch's native forward-mode AD carries the tangents through.
            _, vjp_fn = ref_vjp(fn, *primals)
            with fwAD.dual_level():
                # `spec` is rebound on each tree_flatten call; cotangents and
                # tangents mirror fn's output, so the specs agree — the last
                # binding before each unflatten is the one that matters.
                flat_cotangents, spec = tree_flatten(cotangents)
                flat_tangents, spec = tree_flatten(tangents)
                flat_duals = [fwAD.make_dual(c, t) for c, t in zip(flat_cotangents, flat_tangents)]
                duals = tree_unflatten(flat_duals, spec)
                result = vjp_fn(duals)
                flat_result, spec = tree_flatten(result)
                primals_out, tangents_out = zip(*[fwAD.unpack_dual(r) for r in flat_result])
                # unpack_dual yields tangent=None when an output does not
                # depend on the input tangents; treat that as a zero tangent.
                tangents_out = [t if t is not None else torch.zeros_like(p)
                                for p, t in zip(primals_out, tangents_out)]
                expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
            return expected

        expected = reference(primals, cotangents, tangents)
        self.assertEqual(result, expected)
1135+
10221136

10231137
class InplaceError(Exception):
10241138
def __repr__(self):

test/xfail_suggester.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def get_failed_test(line):
3232
'test_jvp_',
3333
'test_vmapjvp_',
3434
'test_vmapjvpall_',
35+
'test_jvpvjp_',
3536
'test_decomposition_',
3637
'test_make_fx_',
3738
}

0 commit comments

Comments
 (0)