Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit b94ece0

Browse files
authored
Beef up jvpvjp testing (#648)
1 parent c5cffb8 commit b94ece0

File tree

3 files changed

+172
-82
lines changed

3 files changed

+172
-82
lines changed

test/discover_coverage.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,11 @@ def supports_jvp(self):
677677
return Support.NO
678678
return self.no_opinfos_skip_test('test_jvp')
679679

680+
def supports_jvpvjp(self):
681+
if self.name in FACTORY_FNS:
682+
return Support.YES
683+
return self.no_opinfos_skip_test('test_jvpvjp')
684+
680685
def _supports_vmapjvp_base(self, test):
681686
if self.name in FACTORY_FNS:
682687
return Support.YES
@@ -755,6 +760,7 @@ def summary(self):
755760
'supports_jvp',
756761
'supports_vmapjvp',
757762
'supports_fast_vmapjvp',
763+
'supports_jvpvjp',
758764
]
759765
result = ['test, yes, no, unknown']
760766
for check in checks:
@@ -779,9 +785,10 @@ def summary(self):
779785

780786
print("=" * 30 + " Top 125 Summary " + "=" * 30)
781787
opset = OperatorSet.from_top125()
782-
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
788+
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
789+
pprint.pprint(result)
790+
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
783791
pprint.pprint(result)
784-
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
785792
# pprint.pprint(result)
786793
print(opset.summary())
787794

test/test_ops.py

Lines changed: 117 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,69 +1055,109 @@ def test_vjpvmap(self, device, dtype, op):
10551055
skip('nn.functional.fractional_max_pool2d'), # Random
10561056
skip('nn.functional.fractional_max_pool3d'), # Random
10571057
1058-
xfail('_masked.log_softmax'),
1059-
xfail('_masked.softmax'),
1060-
xfail('_masked.softmin'),
1061-
xfail('block_diag'),
1062-
xfail('cdist'),
1063-
xfail('fft.fft'),
1064-
xfail('fft.fft2'),
1065-
xfail('fft.fftn'),
1066-
xfail('fft.hfft'),
1067-
xfail('fft.hfft2'),
1068-
xfail('fft.hfftn'),
1069-
xfail('fft.ifft'),
1070-
xfail('fft.ifft2'),
1071-
xfail('fft.ifftn'),
1072-
xfail('fft.ihfft'),
1073-
xfail('fft.ihfft2'),
1074-
xfail('fft.ihfftn'),
1075-
xfail('fft.irfft'),
1076-
xfail('fft.irfft2'),
1077-
xfail('fft.irfftn'),
1078-
xfail('fft.rfft'),
1079-
xfail('fft.rfft2'),
1080-
xfail('fft.rfftn'),
1081-
xfail('istft'),
1082-
xfail('log_softmax'),
1058+
xfail('__rsub__', ''),
1059+
xfail('_masked.amax', ''),
1060+
xfail('_masked.amin', ''),
1061+
xfail('_masked.log_softmax', ''),
1062+
xfail('_masked.norm', ''),
1063+
xfail('_masked.normalize', ''),
1064+
xfail('_masked.softmax', ''),
1065+
xfail('_masked.softmin', ''),
1066+
xfail('amax', ''),
1067+
xfail('amin', ''),
1068+
xfail('atan2', ''),
1069+
xfail('block_diag', ''),
1070+
xfail('cdist', ''),
1071+
xfail('cholesky', ''),
1072+
xfail('cholesky_inverse', ''),
1073+
xfail('dist', ''),
1074+
xfail('eig', ''),
1075+
xfail('fft.fft', ''),
1076+
xfail('fft.fft2', ''),
1077+
xfail('fft.fftn', ''),
1078+
xfail('fft.hfft', ''),
1079+
xfail('fft.hfft2', ''),
1080+
xfail('fft.hfftn', ''),
1081+
xfail('fft.ifft', ''),
1082+
xfail('fft.ifft2', ''),
1083+
xfail('fft.ifftn', ''),
1084+
xfail('fft.ihfft', ''),
1085+
xfail('fft.ihfft2', ''),
1086+
xfail('fft.ihfftn', ''),
1087+
xfail('fft.irfft', ''),
1088+
xfail('fft.irfft2', ''),
1089+
xfail('fft.irfftn', ''),
1090+
xfail('fft.rfft', ''),
1091+
xfail('fft.rfft2', ''),
1092+
xfail('fft.rfftn', ''),
1093+
xfail('istft', ''),
1094+
xfail('linalg.det', ''),
1095+
xfail('linalg.eigh', ''),
1096+
xfail('linalg.eigvalsh', ''),
1097+
xfail('linalg.matrix_norm', ''),
1098+
xfail('linalg.norm', ''),
1099+
xfail('linalg.slogdet', ''),
1100+
xfail('linalg.vector_norm', ''),
1101+
xfail('log_softmax', ''),
10831102
xfail('log_softmax', 'dtype'),
1084-
xfail('logcumsumexp'),
1085-
xfail('nn.functional.batch_norm'),
1103+
xfail('logcumsumexp', ''),
1104+
xfail('logdet', ''),
1105+
xfail('lu', ''),
1106+
xfail('lu_solve', ''),
1107+
xfail('lu_unpack', ''),
1108+
xfail('max', 'binary'),
1109+
xfail('maximum', ''),
1110+
xfail('min', 'binary'),
1111+
xfail('minimum', ''),
1112+
xfail('nanmean', ''),
1113+
xfail('nansum', ''),
1114+
xfail('nn.functional.batch_norm', ''),
10861115
xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
1087-
xfail('nn.functional.bilinear'),
1088-
xfail('nn.functional.binary_cross_entropy'),
1089-
xfail('nn.functional.binary_cross_entropy_with_logits', device_type='cuda'),
1090-
xfail('nn.functional.celu'),
1091-
xfail('nn.functional.cross_entropy'),
1116+
xfail('nn.functional.bilinear', ''),
1117+
xfail('nn.functional.binary_cross_entropy', ''),
1118+
xfail('nn.functional.binary_cross_entropy_with_logits', ''),
1119+
xfail('nn.functional.celu', ''),
1120+
xfail('nn.functional.cross_entropy', ''),
10921121
xfail('nn.functional.cross_entropy', 'mean'),
10931122
xfail('nn.functional.cross_entropy', 'none'),
10941123
xfail('nn.functional.cross_entropy', 'sum'),
1095-
xfail('nn.functional.elu'),
1096-
xfail('nn.functional.embedding'),
1124+
xfail('nn.functional.elu', ''),
1125+
xfail('nn.functional.embedding', ''),
10971126
xfail('nn.functional.embedding', 'functorch'),
1098-
xfail('nn.functional.embedding_bag'),
1099-
xfail('nn.functional.glu'),
1100-
xfail('nn.functional.grid_sample'),
1101-
xfail('nn.functional.hardsigmoid'),
1102-
xfail('nn.functional.hardswish'),
1103-
xfail('nn.functional.huber_loss'),
1104-
xfail('nn.functional.instance_norm'),
1105-
xfail('nn.functional.layer_norm'),
1106-
xfail('nn.functional.leaky_relu'),
1107-
xfail('nn.functional.logsigmoid'),
1108-
xfail('nn.functional.mse_loss'),
1109-
xfail('nn.functional.nll_loss'),
1127+
xfail('nn.functional.embedding_bag', ''),
1128+
xfail('nn.functional.glu', ''),
1129+
xfail('nn.functional.grid_sample', ''),
1130+
xfail('nn.functional.hardsigmoid', ''),
1131+
xfail('nn.functional.hardswish', ''),
1132+
xfail('nn.functional.huber_loss', ''),
1133+
xfail('nn.functional.instance_norm', ''),
1134+
xfail('nn.functional.layer_norm', ''),
1135+
xfail('nn.functional.leaky_relu', ''),
1136+
xfail('nn.functional.logsigmoid', ''),
1137+
xfail('nn.functional.mse_loss', ''),
1138+
xfail('nn.functional.nll_loss', ''),
1139+
xfail('nn.functional.normalize', ''),
11101140
xfail('nn.functional.pad', 'circular'),
1111-
xfail('nn.functional.prelu'),
1112-
xfail('nn.functional.selu'),
1113-
xfail('nn.functional.softmin'),
1141+
xfail('nn.functional.pairwise_distance', ''),
1142+
xfail('nn.functional.prelu', ''),
1143+
xfail('nn.functional.selu', ''),
1144+
xfail('nn.functional.softmin', ''),
11141145
xfail('nn.functional.softmin', 'with_dtype'),
1115-
xfail('nn.functional.softplus'),
1116-
xfail('put'),
1117-
xfail('softmax'),
1146+
xfail('nn.functional.softplus', ''),
1147+
xfail('norm', ''),
1148+
xfail('norm', 'fro'),
1149+
xfail('norm', 'inf'),
1150+
xfail('polar', ''),
1151+
xfail('put', ''),
1152+
xfail('renorm', ''),
1153+
xfail('softmax', ''),
11181154
xfail('softmax', 'with_dtype'),
1119-
xfail('stft'),
1120-
xfail('take'),
1155+
xfail('solve', ''),
1156+
xfail('std_mean', ''),
1157+
xfail('stft', ''),
1158+
xfail('symeig', ''),
1159+
xfail('take', ''),
1160+
xfail('var_mean', ''),
11211161
}))
11221162
def test_jvpvjp(self, device, dtype, op):
11231163
if not op.supports_autograd:
@@ -1135,28 +1175,40 @@ def test_jvpvjp(self, device, dtype, op):
11351175
fn, primals = normalize_op_input_output(op, sample)
11361176
result = fn(*primals)
11371177
cotangents = tree_map(lambda x: torch.randn_like(x), result)
1138-
tangents = tree_map(lambda x: torch.randn_like(x), result)
11391178

1140-
_, vjp_fn = vjp(fn, *primals)
1141-
result = jvp(vjp_fn, (cotangents,), (tangents,))
1179+
primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
1180+
cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
1181+
1182+
def push_vjp(primals, cotangents):
1183+
_, vjp_fn = vjp(fn, *primals)
1184+
return vjp_fn(cotangents)
1185+
1186+
result = jvp(push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents))
11421187
self.assertEqual(len(result), 2)
11431188

1144-
def reference(primals, cotangents, tangents):
1145-
_, vjp_fn = ref_vjp(fn, *primals)
1189+
def tree_map2(fn, first, second):
1190+
flat_first, spec_first = tree_flatten(first)
1191+
flat_second, spec_second = tree_flatten(second)
1192+
assert spec_first == spec_second
1193+
flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)]
1194+
return tree_unflatten(flat_result, spec_first)
1195+
1196+
def reference(primals, cotangents, primals_tangents, cotangents_tangents):
11461197
with fwAD.dual_level():
1147-
flat_cotangents, spec = tree_flatten(cotangents)
1148-
flat_tangents, spec = tree_flatten(tangents)
1149-
flat_duals = [fwAD.make_dual(c, t) for c, t in zip(flat_cotangents, flat_tangents)]
1150-
duals = tree_unflatten(flat_duals, spec)
1151-
result = vjp_fn(duals)
1198+
primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents)
1199+
_, vjp_fn = ref_vjp(fn, *primal_duals)
1200+
1201+
cotangent_duals = tree_map2(fwAD.make_dual, cotangents, cotangents_tangents)
1202+
result = vjp_fn(cotangent_duals)
1203+
11521204
flat_result, spec = tree_flatten(result)
11531205
primals_out, tangents_out = zip(*[fwAD.unpack_dual(r) for r in flat_result])
11541206
tangents_out = [t if t is not None else torch.zeros_like(p)
11551207
for p, t in zip(primals_out, tangents_out)]
11561208
expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
11571209
return expected
11581210

1159-
expected = reference(primals, cotangents, tangents)
1211+
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
11601212
self.assertEqual(result, expected)
11611213

11621214

test/xfail_suggester.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
import torch
23

34
"""
45
Instructions:
@@ -58,16 +59,42 @@ def belongs_to_base(test, base):
5859
return True
5960

6061

61-
def sanitize_base(base):
62-
if base.startswith('nn_functional_'):
63-
base = f'nn.functional.{base[len("nn_functional_"):]}'
64-
if base.startswith('fft_'):
65-
base = f'fft.{base[len("fft_"):]}'
66-
if base.startswith('linalg_'):
67-
base = f'linalg.{base[len("linalg."):]}'
68-
if base.startswith('_masked_'):
69-
base = f'_masked.{base[len("_masked_"):]}'
70-
return base
62+
def parse_namespace(base):
63+
mappings = {
64+
'nn_functional_': 'nn.functional',
65+
'fft_': 'fft',
66+
'linalg_': 'linalg',
67+
'_masked_': '_masked',
68+
}
69+
for heading in mappings.keys():
70+
if base.startswith(heading):
71+
return mappings[heading], base[len(heading):]
72+
return None, base
73+
74+
75+
def get_torch_module(namespace):
76+
if namespace is None:
77+
return torch
78+
if namespace == 'nn.functional':
79+
return torch.nn.functional
80+
return getattr(torch, namespace)
81+
82+
83+
def parse_base(base):
84+
namespace, rest = parse_namespace(base)
85+
86+
apis = dir(get_torch_module(namespace))
87+
apis = sorted(apis, key=lambda x: -len(x))
88+
89+
api = rest
90+
variant = ''
91+
for candidate in apis:
92+
if rest.startswith(candidate):
93+
api = candidate
94+
variant = rest[len(candidate) + 1:]
95+
break
96+
print(base, namespace, api, variant)
97+
return namespace, api, variant
7198

7299

73100
def any_starts_with(strs, thing):
@@ -87,17 +114,21 @@ def get_suggested_xfails(base, tests):
87114
for base in base_tests:
88115
cpu_variant = base + '_cpu_float32'
89116
cuda_variant = base + '_cuda_float32'
90-
sanitized_base = sanitize_base(base)
117+
namespace, api, variant = parse_base(base)
118+
if namespace is None:
119+
api = api
120+
else:
121+
api = f'{namespace}.{api}'
91122
if cpu_variant in tests and cuda_variant in tests:
92-
result.append(f"xfail('{sanitized_base}'),")
123+
result.append(f"xfail('{api}', '{variant}'),")
93124
continue
94125
if cpu_variant in tests:
95-
result.append(f"xfail('{sanitized_base}', device_type='cpu'),")
126+
result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
96127
continue
97128
if cuda_variant in tests:
98-
result.append(f"xfail('{sanitized_base}', device_type='cuda'),")
129+
result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
99130
continue
100-
result.append(f"skip('{sanitized_base}'),")
131+
result.append(f"skip('{api}', '{variant}',")
101132
return result
102133

103134

0 commit comments

Comments
 (0)