@@ -131,7 +131,7 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, creat
131
131
132
132
133
133
# How do we increment and decrement the nesting? I don't think we can.
134
- def vjp (f : Callable , * primals ):
134
+ def vjp (f : Callable , * primals , has_aux = False ):
135
135
"""
136
136
Standing for the vector-Jacobian product, returns a tuple containing the
137
137
results of :attr:`f` applied to :attr:`primals` and a function that, when
@@ -144,13 +144,19 @@ def vjp(f: Callable, *primals):
144
144
primals (Tensors): Positional arguments to :attr:`f` that must all be
145
145
Tensors. The returned function will also be computing the
146
146
derivative with respect to these arguments
147
+ has_aux (bool): Flag indicating that :attr:`f` returns a
148
+ ``(output, aux)`` tuple where the first element is the output of
149
+ the function to be differentiated and the second element is
150
+ other auxiliary objects that will not be differentiated.
151
+ Default: False.
147
152
148
153
Returns:
149
- Returns a tuple containing the output of :attr:`f` applied to
150
- :attr:`primals` and a function that computes the vjp of :attr:`f` with
151
- respect to all :attr:`primals` using the cotangents passed to the
152
- returned function. The returned function will return a tuple of each
153
- VJP
154
+ Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`f`
155
+ applied to :attr:`primals` and a function that computes the vjp of
156
+ :attr:`f` with respect to all :attr:`primals` using the cotangents passed
157
+ to the returned function. If ``has_aux is True``, then instead returns a
158
+ ``(output, vjp_fn, aux)`` tuple.
159
+ The returned ``vjp_fn`` function will return a tuple of each VJP.
154
160
155
161
When used in simple cases, :func:`vjp` behaves the same as :func:`grad`
156
162
@@ -228,6 +234,10 @@ def vjp(f: Callable, *primals):
228
234
diff_primals = _create_differentiable (primals , level )
229
235
primals_out = f (* diff_primals )
230
236
237
+ if has_aux :
238
+ primals_out , aux = primals_out
239
+ aux = _undo_create_differentiable (aux , level )
240
+
231
241
results = _undo_create_differentiable (primals_out , level )
232
242
flat_diff_primals , primals_spec = tree_flatten (diff_primals )
233
243
flat_primals_out , primals_out_spec = tree_flatten (primals_out )
@@ -257,13 +267,16 @@ def wrapper(cotangents, retain_graph=True, create_graph=None):
257
267
finally :
258
268
_grad_decrement_nesting ()
259
269
260
- return results , wrapper
270
+ if has_aux :
271
+ return results , wrapper , aux
272
+ else :
273
+ return results , wrapper
261
274
262
275
def _safe_zero_index (x ):
263
276
assert len (x ) == 1
264
277
return x [0 ]
265
278
266
- def jacrev (f : Callable , argnums : Union [int , Tuple [int ]] = 0 ):
279
+ def jacrev (f : Callable , argnums : Union [int , Tuple [int ]] = 0 , * , has_aux = False ):
267
280
"""
268
281
Computes the Jacobian of :attr:`f` with respect to the arg(s) at index
269
282
:attr:`argnum` using reverse mode autodiff
@@ -274,11 +287,18 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0):
274
287
argnums (int or Tuple[int]): Optional, integer or tuple of integers,
275
288
saying which arguments to get the Jacobian with respect to.
276
289
Default: 0.
290
+ has_aux (bool): Flag indicating that :attr:`f` returns a
291
+ ``(output, aux)`` tuple where the first element is the output of
292
+ the function to be differentiated and the second element is
293
+ auxiliary objects that will not be differentiated.
294
+ Default: False.
277
295
278
296
Returns:
279
297
Returns a function that takes in the same inputs as :attr:`f` and
280
298
returns the Jacobian of :attr:`f` with respect to the arg(s) at
281
- :attr:`argnums`
299
+ :attr:`argnums`. If ``has_aux is True``, then the returned function
300
+ instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
301
+ is the Jacobian and ``aux`` is auxiliary objects returned by ``f``.
282
302
283
303
A basic usage with a pointwise, unary operation will give a diagonal array
284
304
as the Jacobian
@@ -358,7 +378,11 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0):
358
378
@wraps (f )
359
379
def wrapper_fn (* args ):
360
380
f_wrapper , primals = _argnums_partial (f , args , argnums )
361
- output , vjp_fn = vjp (f_wrapper , * primals )
381
+ vjp_out = vjp (f_wrapper , * primals , has_aux = has_aux )
382
+ if has_aux :
383
+ output , vjp_fn , aux = vjp_out
384
+ else :
385
+ output , vjp_fn = vjp_out
362
386
363
387
# See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
364
388
flat_output , output_spec = tree_flatten (output )
@@ -409,6 +433,8 @@ def wrapper_fn(*args):
409
433
flat_output_input = tuple (_safe_zero_index (flat_input )
410
434
for flat_input in flat_output_input )
411
435
output_input = tree_unflatten (flat_output_input , output_spec )
436
+ if has_aux :
437
+ return output_input , aux
412
438
return output_input
413
439
return wrapper_fn
414
440
0 commit comments