
Commit e8ab067

Fix vmapvjp tests (#354)
Previously they were not vmapping over non-differentiable arguments. That case is important to test.

Test Plan:
- run tests
1 parent 5d043be commit e8ab067
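
To illustrate the case the commit message refers to, here is a minimal sketch (not part of the commit) of vmapping over a non-differentiable argument while the vjp is taken only with respect to the differentiable one. It assumes functorch's public vjp and vmap; the choice of op (index_select), the shapes, and the helper name vjp_wrt_x_only are arbitrary illustrations, not names from this diff.

import torch
from functorch import vjp, vmap

def op(x, idx):
    # `idx` is a non-differentiable (integer) argument
    return x.index_select(0, idx)

x = torch.randn(3, 5)

def vjp_wrt_x_only(x, idx, cotangent):
    # The vjp is computed w.r.t. `x` only; `idx` is closed over, not a primal.
    _, vjp_fn = vjp(lambda x_: op(x_, idx), x)
    return vjp_fn(cotangent)

# vmap batches over the non-differentiable `idx` and the cotangents, not `x`.
batched_idx = torch.tensor([[0, 2], [1, 2]])
batched_cotangents = torch.randn(2, 2, 5)  # per-example output shape is (2, 5)
batched_grads = vmap(vjp_wrt_x_only, in_dims=(None, 0, 0))(x, batched_idx, batched_cotangents)

The updated tests exercise this pattern systematically: get_sample_cotangents fixes a set of cotangents for a sample, and get_vjp_fn_and_args_with_cotangents exposes every flattened argument (differentiable or not) plus the cotangents as inputs that get_fallback_and_vmap_exhaustive can vmap over.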

test/test_ops.py

Lines changed: 78 additions & 9 deletions
@@ -82,7 +82,33 @@ def is_differentiable_arg(arg):
 def normalize_op_input_output2(f, args, kwargs, output_process_fn_grad=None, requires_grad=True):
     flat_args, args_spec = tree_flatten(args)
     diff_argnums = tuple(i for i, arg in enumerate(flat_args) if diff_arg(arg, requires_grad=requires_grad))
+    assert len(diff_argnums) > 0
+    primals = tuple(flat_args[i] for i in diff_argnums)
+
+    @functools.wraps(f)
+    def wrapped(*primals):
+        _args = list(flat_args)
+        for num, arg in zip(diff_argnums, primals):
+            _args[num] = arg
+        _args = tree_unflatten(_args, args_spec)
+        result = f(*_args, **kwargs)
+        if output_process_fn_grad is not None:
+            result = output_process_fn_grad(result)
+        if isinstance(result, tuple):
+            # TODO: Remove the following hack for namedtuples
+            result = tuple(result)
+            result = tuple(r for r in result if torch.is_floating_point(r))
+            assert len(result) > 0
+        return result
+    return wrapped, primals
+

+# TODO: consolidate with normalize_op_input_output2
+def normalize_op_input_output3(f, args, kwargs, sample_args, output_process_fn_grad=None):
+    flat_args, args_spec = tree_flatten(args)
+    flat_sample_args, _ = tree_flatten(sample_args)
+    diff_argnums = tuple(i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
+                         if diff_arg(sample, requires_grad=True))
     assert len(diff_argnums) > 0
     primals = tuple(flat_args[i] for i in diff_argnums)

@@ -128,10 +154,43 @@ def ref_jvp(f, primals, tangents):
     primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals))
     return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)

-# Returns a new function g(*args, *cotangents) that computes vjps and
-# sample (*args, *cotangents)

+def get_sample_cotangents(f, sample):
+    fn, primals = normalize_op_input_output(f, sample)
+    output = fn(*primals)
+    if isinstance(output, tuple):
+        # TODO: Remove the following hack for torch.return_types
+        output = tuple(output)
+    return tree_map(torch.randn_like, output)
+
+
+# returns a new function g(*args, *cotangents)
+# that computes vjps and (*args, cotangents)
+def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents):
+    args = tuple([sample.input] + list(sample.args))
+    kwargs = sample.kwargs
+    flat_args, args_spec = tree_flatten(args)
+    flat_cotangents, cotangents_spec = tree_flatten(cotangents)
+
+    @functools.wraps(f)
+    def wrapped(*args):
+        assert len(args) == len(flat_args) + len(flat_cotangents)
+        actual_args = args[:len(flat_args)]
+        cotangents = args[len(flat_args):]
+        actual_args = tree_unflatten(actual_args, args_spec)
+        cotangents = tree_unflatten(cotangents, cotangents_spec)
+
+        fn, primals = normalize_op_input_output3(f, actual_args, kwargs,
+                                                 flat_args,
+                                                 sample.output_process_fn_grad)
+        _, vjp_fn = vjp(fn, *primals)
+        return vjp_fn(cotangents)
+
+    return wrapped, tuple(flat_args + flat_cotangents)

+
+# Returns a new function g(*args, *cotangents) that computes vjps and
+# sample (*args, *cotangents)
 def get_vjpfull_variant(f, sample):
     fn, primals = normalize_op_input_output(f, sample)
     result = fn(*primals)
@@ -438,6 +497,10 @@ def vjp_of_vjp(*args_and_cotangents):
                     get_fallback_and_vmap_exhaustive(vjp_of_vjp, args_and_cotangents, {}):
                 self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)
     vmapvjp_fail = vjp_fail.union({
+        # The following are not bugs and are expected behavior
+        xfail('fill_'),  # Not possible, wontfix
+        xfail('masked_select'),  # Not possible due to dynamic shapes
+
         # All of the following are bugs and need to be fixed
         xfail('diag_embed'),
         xfail('eig'),
@@ -447,7 +510,6 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('fft.rfft'),
         xfail('fft.rfft'),
         xfail('fft.rfftn'),
-        xfail('cdist'),
         xfail('fmax'),
         xfail('fmin'),
         xfail('index_copy'),
@@ -487,6 +549,9 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('nn.functional.fractional_max_pool3d'),
         xfail('as_strided'),
         xfail('nn.functional.fractional_max_pool2d'),
+        xfail('__getitem__'),
+        xfail('index_put'),
+        xfail('lu_solve'),
     })

     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -504,7 +569,8 @@ def test_vmapvjp(self, device, dtype, op):
             return

        for sample in samples:
-            fn, args = get_vjpfull_variant(op, sample)
+            cotangents = get_sample_cotangents(op, sample)
+            fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}):
                 self.assertEqual(loop_out, batched_out, atol=1e-4, rtol=1e-4)

@@ -796,13 +862,17 @@ def test_vmapvjp_has_batch_rule(self, device, dtype, op):

         def test():
             for sample in samples:
-                fn, args = get_vjpfull_variant(op, sample)
-                for _ in get_fallback_and_vmap_exhaustive(fn, args, {}, compute_loop_out=False):
+                cotangents = get_sample_cotangents(op, sample)
+                fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
+                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
+                        fn, args, {}, compute_loop_out=False):
                     pass
                 for a_op in op.aliases:
-                    fn, args = get_vjpfull_variant(a_op, sample)
-                    for _ in get_fallback_and_vmap_exhaustive(fn, args, {}, compute_loop_out=False):
+                    fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
+                    for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
+                            fn, args, {}, compute_loop_out=False):
                         pass
+
         check_vmap_fallback(self, test, op, dry_run=False)

     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -823,7 +893,6 @@ def test():
         xfail('linalg.multi_dot'),
         xfail('vstack'),
         xfail('nn.functional.batch_norm'),
-        xfail('cdist'),
         xfail('lu_solve'),
         xfail('lu_unpack'),
         xfail('matrix_exp'),
