This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d3477c1 (parent: eccb280)

Add has_aux=False to vjp, jacrev (#335)
Mitigates #333, but we should still consider whether or not we want a jacrev_and_value API.

Test Plan:
- New tests

File tree

- functorch/_src/eager_transforms.py
- test/test_eager_transforms.py

2 files changed: +92 -10 lines changed


functorch/_src/eager_transforms.py

Lines changed: 36 additions & 10 deletions
@@ -131,7 +131,7 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, creat
 
 
 # How do we increment and decrement the nesting? I don't think we can.
-def vjp(f: Callable, *primals):
+def vjp(f: Callable, *primals, has_aux=False):
     """
     Standing for the vector-Jacobian product, returns a tuple containing the
     results of :attr:`f` applied to :attr:`primals` and a function that, when
@@ -144,13 +144,19 @@ def vjp(f: Callable, *primals):
         primals (Tensors): Positional arguments to :attr:`f` that must all be
             Tensors. The returned function will also be computing the
             derivative with respect to these arguments
+        has_aux (bool): Flag indicating that :attr:`f` returns a
+            ``(output, aux)`` tuple where the first element is the output of
+            the function to be differentiated and the second element is
+            other auxiliary objects that will not be differentiated.
+            Default: False.
 
     Returns:
-        Returns a tuple containing the output of :attr:`f` applied to
-        :attr:`primals` and a function that computes the vjp of :attr:`f` with
-        respect to all :attr:`primals` using the cotangents passed to the
-        returned function. The returned function will return a tuple of each
-        VJP
+        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`f`
+        applied to :attr:`primals` and a function that computes the vjp of
+        :attr:`f` with respect to all :attr:`primals` using the cotangents passed
+        to the returned function. If ``has_aux is True``, then instead returns a
+        ``(output, vjp_fn, aux)`` tuple.
+        The returned ``vjp_fn`` function will return a tuple of each VJP.
 
     When used in simple cases, :func:`vjp` behaves the same as :func:`grad`
 
@@ -228,6 +234,10 @@ def vjp(f: Callable, *primals):
         diff_primals = _create_differentiable(primals, level)
         primals_out = f(*diff_primals)
 
+        if has_aux:
+            primals_out, aux = primals_out
+            aux = _undo_create_differentiable(aux, level)
+
         results = _undo_create_differentiable(primals_out, level)
         flat_diff_primals, primals_spec = tree_flatten(diff_primals)
         flat_primals_out, primals_out_spec = tree_flatten(primals_out)
@@ -257,13 +267,16 @@ def wrapper(cotangents, retain_graph=True, create_graph=None):
     finally:
         _grad_decrement_nesting()
 
-    return results, wrapper
+    if has_aux:
+        return results, wrapper, aux
+    else:
+        return results, wrapper
 
 def _safe_zero_index(x):
     assert len(x) == 1
     return x[0]
 
-def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0):
+def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
     """
     Computes the Jacobian of :attr:`f` with respect to the arg(s) at index
     :attr:`argnum` using reverse mode autodiff
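
The hunks above add the ``has_aux`` path to ``vjp``: the auxiliary output is peeled off before differentiation and returned as a third element. A minimal usage sketch, assuming the usual top-level functorch import; the function f mirrors the new test_vjp_aux_tensor test:

import torch
from functorch import vjp

def f(x):
    # Differentiated output plus an auxiliary value that is not differentiated.
    return x.sin(), x.cos()

x = torch.randn(3)

# With has_aux=True, vjp returns (output, vjp_fn, aux) instead of (output, vjp_fn).
out, vjp_fn, aux = vjp(f, x, has_aux=True)

# vjp_fn takes cotangents matching the output and returns one VJP per primal.
v = torch.ones_like(out)
(grad_x,) = vjp_fn(v)   # grad_x == v * x.cos()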
@@ -274,11 +287,18 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0):
         argnums (int or Tuple[int]): Optional, integer or tuple of integers,
             saying which arguments to get the Jacobian with respect to.
             Default: 0.
+        has_aux (bool): Flag indicating that :attr:`f` returns a
+            ``(output, aux)`` tuple where the first element is the output of
+            the function to be differentiated and the second element is
+            auxiliary objects that will not be differentiated.
+            Default: False.
 
     Returns:
         Returns a function that takes in the same inputs as :attr:`f` and
         returns the Jacobian of :attr:`f` with respect to the arg(s) at
-        :attr:`argnums`
+        :attr:`argnums`. If ``has_aux is True``, then the returned function
+        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
+        is the Jacobian and ``aux`` is auxiliary objects returned by ``f``.
 
     A basic usage with a pointwise, unary operation will give a diagonal array
     as the Jacobian
@@ -358,7 +378,11 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0):
     @wraps(f)
     def wrapper_fn(*args):
        f_wrapper, primals = _argnums_partial(f, args, argnums)
-        output, vjp_fn = vjp(f_wrapper, *primals)
+        vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
+        if has_aux:
+            output, vjp_fn, aux = vjp_out
+        else:
+            output, vjp_fn = vjp_out
 
         # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
         flat_output, output_spec = tree_flatten(output)
@@ -409,6 +433,8 @@ def wrapper_fn(*args):
         flat_output_input = tuple(_safe_zero_index(flat_input)
                                   for flat_input in flat_output_input)
         output_input = tree_unflatten(flat_output_input, output_spec)
+        if has_aux:
+            return output_input, aux
         return output_input
     return wrapper_fn
 
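
With the plumbing above, jacrev forwards ``has_aux`` to ``vjp`` and returns the auxiliary output alongside the Jacobian. A short sketch, again assuming the top-level functorch import; the function g is illustrative:

import torch
from functorch import jacrev

def g(x):
    # Return the output to differentiate plus a dict of auxiliary values.
    return x.sin(), {'cos': x.cos()}

x = torch.randn(3)

# With has_aux=True the wrapped function returns a (jacobian, aux) pair.
jacobian, aux = jacrev(g, has_aux=True)(x)

# For an elementwise function the Jacobian is diagonal: diag(cos(x)).
assert torch.allclose(jacobian, torch.diag(x.cos()))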

test/test_eager_transforms.py

Lines changed: 56 additions & 0 deletions
@@ -603,6 +603,37 @@ def f(x):
         with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
             result, = vjp_fn(((v1, (v2, v3)),))
 
+    def test_vjp_aux_tensor(self, device):
+        def f(x):
+            y = x.sin()
+            return y, x.cos()
+
+        x = torch.randn(3, device=device)
+
+        out, vjp_fn, aux = vjp(f, x, has_aux=True)
+        self.assertEqual(aux, x.cos())
+        self.assertEqual(out, x.sin())
+
+        v = torch.randn(3, device=device)
+        grad_x, = vjp_fn(v)
+        self.assertEqual(grad_x, v * x.cos())
+
+    def test_vjp_aux_pytree(self, device):
+        def f(x):
+            y = x.sin()
+            return y, {'a': x.cos(), 'b': [x.tan()]}
+
+        x = torch.randn(3, device=device)
+
+        out, vjp_fn, aux = vjp(f, x, has_aux=True)
+        expected_out, expected_aux = f(x)
+        self.assertEqual(out, expected_out)
+        self.assertEqual(aux, expected_aux)
+
+        v = torch.randn(3, device=device)
+        grad_x, = vjp_fn(v)
+        self.assertEqual(grad_x, v * x.cos())
+
     def test_functional_init(self, device):
         class MLPClassifier(nn.Module):
             def __init__(self, hidden_dim=32, n_classes=2):
@@ -1112,6 +1143,31 @@ def f(x):
         self.assertEqual(result.dim(), 2)
         self.assertEqual(result, x.new_ones(1, 1))
 
+    @FIXME_jacrev_only
+    def test_aux_tensor(self, device, jacapi):
+        def f(x):
+            y = x.clone()
+            return y, y.cos()
+
+        x = torch.randn(3, device=device)
+        result, aux = jacapi(f, has_aux=True)(x)
+
+        self.assertEqual(result, torch.eye(3, 3, device=device))
+        self.assertEqual(aux, x.cos())
+
+    @FIXME_jacrev_only
+    def test_aux_pytree(self, device, jacapi):
+        def f(x):
+            y = x.clone()
+            return y, {'a': y.cos(), 'b': [y.tan()]}
+
+        x = torch.randn(3, device=device)
+        result, aux = jacapi(f, has_aux=True)(x)
+
+        self.assertEqual(result, torch.eye(3, 3, device=device))
+        expected_aux = {'a': x.cos(), 'b': [x.tan()]}
+        self.assertEqual(aux, expected_aux)
+
     @FIXME_jacrev_only
     def test_multiple_inputs_outputs_pytree(self, device, jacapi):
         def f(a, b, c):
