
Commit 7fa79f9

More docs for jacfwd, hessian, jvp (#362)
Also moved them out of the experimental namespace, but added a warning about forward-mode AD operator coverage.
1 parent 7a37f6b commit 7fa79f9

File tree

3 files changed: +189 -26 lines


functorch/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -17,7 +17,9 @@
 
 # functorch transforms
 from ._src.vmap import vmap
-from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
+from ._src.eager_transforms import (
+    grad, grad_and_value, vjp, jacrev, jvp, jacfwd, hessian,
+)
 from ._src.python_key import make_fx
 
 # utilities. Maybe these should go in their own namespace in the future?
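With this change the three transforms become importable straight from the top-level package. A minimal sanity-check sketch (assuming a functorch build that includes this commit; ``torch`` is imported explicitly since the diff above omits it):

>>> import torch
>>> from functorch import grad, vjp, jacrev, jvp, jacfwd, hessian
>>> x = torch.randn(3)
>>> assert torch.allclose(grad(lambda t: t.sin().sum())(x), x.cos())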

functorch/_src/eager_transforms.py

Lines changed: 185 additions & 24 deletions
@@ -139,29 +139,29 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, creat
 
 
 # How do we increment and decrement the nesting? I don't think we can.
-def vjp(f: Callable, *primals, has_aux=False):
+def vjp(func: Callable, *primals, has_aux=False):
     """
     Standing for the vector-Jacobian product, returns a tuple containing the
-    results of :attr:`f` applied to :attr:`primals` and a function that, when
-    given ``cotangents``, computes the reverse-mode Jacobian of :attr:`f` with
+    results of :attr:`func` applied to :attr:`primals` and a function that, when
+    given ``cotangents``, computes the reverse-mode Jacobian of :attr:`func` with
     respect to :attr:`primals` times ``cotangents``.
 
     Args:
-        f (Callable): A Python function that takes one or more arguments. Must
+        func (Callable): A Python function that takes one or more arguments. Must
             return one or more Tensors.
-        primals (Tensors): Positional arguments to :attr:`f` that must all be
+        primals (Tensors): Positional arguments to :attr:`func` that must all be
             Tensors. The returned function will also be computing the
             derivative with respect to these arguments
-        has_aux (bool): Flag indicating that :attr:`f` returns a
+        has_aux (bool): Flag indicating that :attr:`func` returns a
             ``(output, aux)`` tuple where the first element is the output of
             the function to be differentiated and the second element is
             other auxiliary objects that will not be differentiated.
             Default: False.
 
     Returns:
-        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`f`
+        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`func`
         applied to :attr:`primals` and a function that computes the vjp of
-        :attr:`f` with respect to all :attr:`primals` using the cotangents passed
+        :attr:`func` with respect to all :attr:`primals` using the cotangents passed
         to the returned function. If ``has_aux is True``, then instead returns a
         ``(output, vjp_fn, aux)`` tuple.
         The returned ``vjp_fn`` function will return a tuple of each VJP.
@@ -240,7 +240,7 @@ def vjp(f: Callable, *primals, has_aux=False):
     with torch.enable_grad():
         primals = _wrap_all_tensors(primals, level)
         diff_primals = _create_differentiable(primals, level)
-        primals_out = f(*diff_primals)
+        primals_out = func(*diff_primals)
 
         if has_aux:
             primals_out, aux = primals_out
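The hunks above only rename ``f`` to ``func``; the ``(output, vjp_fn)`` contract described in the docstring is unchanged. A minimal usage sketch of that contract (assuming this commit; ``cotangent_out`` is an illustrative name):

>>> import torch
>>> from functorch import vjp
>>> x = torch.randn(5)
>>> output, vjp_fn = vjp(torch.sin, x)
>>> cotangent_out, = vjp_fn(torch.ones(5))  # vjp_fn returns a tuple of VJPs
>>> assert torch.allclose(cotangent_out, x.cos())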
@@ -286,9 +286,9 @@ def _safe_zero_index(x):
     return x[0]
 
 
-def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
+def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
     """
-    Computes the Jacobian of :attr:`f` with respect to the arg(s) at index
+    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
     :attr:`argnum` using reverse mode autodiff
 
     Args:
@@ -297,18 +297,18 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
         argnums (int or Tuple[int]): Optional, integer or tuple of integers,
             saying which arguments to get the Jacobian with respect to.
             Default: 0.
-        has_aux (bool): Flag indicating that :attr:`f` returns a
+        has_aux (bool): Flag indicating that :attr:`func` returns a
             ``(output, aux)`` tuple where the first element is the output of
             the function to be differentiated and the second element is
             auxiliary objects that will not be differentiated.
             Default: False.
 
     Returns:
-        Returns a function that takes in the same inputs as :attr:`f` and
-        returns the Jacobian of :attr:`f` with respect to the arg(s) at
+        Returns a function that takes in the same inputs as :attr:`func` and
+        returns the Jacobian of :attr:`func` with respect to the arg(s) at
         :attr:`argnums`. If ``has_aux is True``, then the returned function
         instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
-        is the Jacobian and ``aux`` is auxiliary objects returned by ``f``.
+        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
 
     A basic usage with a pointwise, unary operation will give a diagonal array
     as the Jacobian
@@ -322,7 +322,7 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
     :func:`jacrev` can be composed with vmap to produce batched
     Jacobians:
 
-    >>> from functorch import jacrev
+    >>> from functorch import jacrev, vmap
     >>> x = torch.randn(64, 5)
     >>> jacobian = vmap(jacrev(torch.sin))(x)
     >>> assert jacobian.shape == (64, 5, 5)
@@ -385,9 +385,9 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
     outer one. This is because ``jacrev`` is a "function transform": its result
     should not depend on the result of a context manager outside of ``f``.
     """
-    @wraps(f)
+    @wraps(func)
     def wrapper_fn(*args):
-        f_wrapper, primals = _argnums_partial(f, args, argnums)
+        f_wrapper, primals = _argnums_partial(func, args, argnums)
         vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
         if has_aux:
             output, vjp_fn, aux = vjp_out
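For the ``has_aux=True`` path wired up above, a rough sketch of the ``(jacobian, aux)`` return described in the docstring (assuming this commit; ``f`` here is an illustrative function, not one from the source):

>>> import torch
>>> from functorch import jacrev
>>> def f(x):
>>>   return x.sin(), x.cos()  # second element is auxiliary and is not differentiated
>>>
>>> x = torch.randn(5)
>>> jacobian, aux = jacrev(f, has_aux=True)(x)
>>> assert torch.allclose(jacobian, torch.diag(x.cos()))
>>> assert torch.allclose(aux, x.cos())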
@@ -654,7 +654,52 @@ def safe_unpack_dual(dual, strict):
     return primal, tangent
 
 
-def jvp(f, primals, tangents, *, strict=False):
+def jvp(func, primals, tangents, *, strict=False):
+    """
+    Standing for the Jacobian-vector product, returns a tuple containing
+    the output of ``func(*primals)`` and the "Jacobian of ``func`` evaluated at
+    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        primals (Tensors): Positional arguments to :attr:`func` that must all be
+            Tensors. The returned function will also be computing the
+            derivative with respect to these arguments
+        tangents (Tensors): The "vector" for which the Jacobian-vector product is
+            computed. Must be the same structure and sizes as the inputs to
+            ``func``.
+
+    Returns:
+        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
+        evaluated at ``primals`` and the Jacobian-vector product.
+
+    .. warning::
+        PyTorch's forward-mode AD coverage on operators is not very good at the
+        moment. You may see this API error out with "forward-mode AD not
+        implemented for operator X". If so, please file us a bug report and we
+        will prioritize it.
+
+    jvp is useful when you wish to compute gradients of a function R^1 -> R^N
+
+    >>> from functorch import jvp
+    >>> x = torch.randn([])
+    >>> f = lambda x: x * torch.tensor([1., 2., 3.])
+    >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
+    >>> assert torch.allclose(value, f(x))
+    >>> assert torch.allclose(grad, torch.tensor([1., 2., 3.]))
+
+    :func:`jvp` can support functions with multiple inputs by passing in the
+    tangents for each of the inputs
+
+    >>> from functorch import jvp
+    >>> x = torch.randn(5)
+    >>> y = torch.randn(5)
+    >>> f = lambda x, y: (x * y)
+    >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
+    >>> assert torch.allclose(output, x + y)
+
+    """
     if not isinstance(primals, tuple):
         raise RuntimeError(
             f'{jvp_str}: Expected primals to be a tuple. '
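As a cross-check of the docstring above, the jvp output equals the full Jacobian applied to the tangent. A sketch (assuming this commit; the Jacobian is materialized with reverse mode only for comparison):

>>> import torch
>>> from functorch import jvp, jacrev
>>> f = lambda x: x.sin().sum()
>>> x, t = torch.randn(5), torch.randn(5)
>>> _, jvp_out = jvp(f, (x,), (t,))
>>> assert torch.allclose(jvp_out, jacrev(f)(x) @ t)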
@@ -683,7 +728,7 @@ def jvp(f, primals, tangents, *, strict=False):
         flat_duals = tuple(fwAD.make_dual(p, t)
                            for p, t in zip(flat_primals, flat_tangents))
         duals = tree_unflatten(flat_duals, primals_spec)
-        result_duals = f(*duals)
+        result_duals = func(*duals)
         assert_output_is_tensor_or_tensors(result_duals, jvp_str)
         result_duals, spec = tree_flatten(result_duals)
 
@@ -706,9 +751,87 @@ def safe_unflatten(tensor, dim, shape):
     return tensor.unflatten(dim, shape)
 
 
-def jacfwd(f, argnums=0):
+def jacfwd(func, argnums=0):
+    """
+    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
+    :attr:`argnums` using forward-mode autodiff
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
+            saying which arguments to get the Jacobian with respect to.
+            Default: 0.
+
+    Returns:
+        Returns a function that takes in the same inputs as :attr:`func` and
+        returns the Jacobian of :attr:`func` with respect to the arg(s) at
+        :attr:`argnums`.
+
+    .. warning::
+        PyTorch's forward-mode AD coverage on operators is not very good at the
+        moment. You may see this API error out with "forward-mode AD not
+        implemented for operator X". If so, please file us a bug report and we
+        will prioritize it.
+
+    A basic usage with a pointwise, unary operation will give a diagonal array
+    as the Jacobian
+
+    >>> from functorch import jacfwd
+    >>> x = torch.randn(5)
+    >>> jacobian = jacfwd(torch.sin)(x)
+    >>> expected = torch.diag(torch.cos(x))
+    >>> assert torch.allclose(jacobian, expected)
+
+    :func:`jacfwd` can be composed with vmap to produce batched
+    Jacobians:
+
+    >>> from functorch import jacfwd, vmap
+    >>> x = torch.randn(64, 5)
+    >>> jacobian = vmap(jacfwd(torch.sin))(x)
+    >>> assert jacobian.shape == (64, 5, 5)
+
+    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
+    to produce Hessians
+
+    >>> from functorch import jacfwd, jacrev
+    >>> def f(x):
+    >>>   return x.sin().sum()
+    >>>
+    >>> x = torch.randn(5)
+    >>> hessian = jacfwd(jacrev(f))(x)
+    >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
+
+    By default, :func:`jacfwd` computes the Jacobian with respect to the first
+    input. However, it can compute the Jacobian with respect to a different
+    argument by using :attr:`argnums`:
+
+    >>> from functorch import jacfwd
+    >>> def f(x, y):
+    >>>   return x + y ** 2
+    >>>
+    >>> x, y = torch.randn(5), torch.randn(5)
+    >>> jacobian = jacfwd(f, argnums=1)(x, y)
+    >>> expected = torch.diag(2 * y)
+    >>> assert torch.allclose(jacobian, expected)
+
+    Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
+    with respect to multiple arguments
+
+    >>> from functorch import jacfwd
+    >>> def f(x, y):
+    >>>   return x + y ** 2
+    >>>
+    >>> x, y = torch.randn(5), torch.randn(5)
+    >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
+    >>> expectedX = torch.diag(torch.ones_like(x))
+    >>> expectedY = torch.diag(2 * y)
+    >>> assert torch.allclose(jacobian[0], expectedX)
+    >>> assert torch.allclose(jacobian[1], expectedY)
+
+    """
     def wrapper_fn(*args):
-        f_wrapper, primals = _argnums_partial(f, args, argnums)
+        f_wrapper, primals = _argnums_partial(func, args, argnums)
         flat_primals, primals_spec = tree_flatten(primals)
         flat_primals_numels = tuple(p.numel() for p in flat_primals)
         flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
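Where forward-mode coverage exists for the operators involved, :func:`jacfwd` and :func:`jacrev` should agree on the materialized Jacobian; a small consistency sketch (assuming this commit):

>>> import torch
>>> from functorch import jacfwd, jacrev
>>> x = torch.randn(7)
>>> assert torch.allclose(jacfwd(torch.tanh)(x), jacrev(torch.tanh)(x))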
@@ -738,8 +861,46 @@ def push_jvp(basis):
     return wrapper_fn
 
 
-def hessian(f, argnums=0):
-    return jacfwd(jacrev(f, argnums), argnums)
+def hessian(func, argnums=0):
+    """
+    Computes the Hessian of :attr:`func` with respect to the arg(s) at index
+    :attr:`argnums` via a forward-over-reverse strategy.
+
+    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
+    a good default that usually gives good performance. It is possible to
+    compute Hessians through other compositions of :func:`jacfwd` and
+    :func:`jacrev` like ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
+            saying which arguments to get the Hessian with respect to.
+            Default: 0.
+
+    Returns:
+        Returns a function that takes in the same inputs as :attr:`func` and
+        returns the Hessian of :attr:`func` with respect to the arg(s) at
+        :attr:`argnums`.
+
+    .. warning::
+        PyTorch's forward-mode AD coverage on operators is not very good at the
+        moment. You may see this API error out with "forward-mode AD not
+        implemented for operator X". If so, please file us a bug report and we
+        will prioritize it.
+
+    A basic usage with an R^N -> R^1 function gives an N x N Hessian:
+
+    >>> from functorch import hessian
+    >>> def f(x):
+    >>>   return x.sin().sum()
+    >>>
+    >>> x = torch.randn(5)
+    >>> hess = hessian(f)(x)
+    >>> assert torch.allclose(hess, torch.diag(-x.sin()))
+
+    """
+    return jacfwd(jacrev(func, argnums), argnums)
 
 
 def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
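Since ``hessian`` is just ``jacfwd(jacrev(func))``, it composes with the other transforms like any function transform. A short sketch of batched Hessians via vmap (assuming this commit and forward-mode coverage for ``sin``):

>>> import torch
>>> from functorch import hessian, vmap
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(8, 5)
>>> hess = vmap(hessian(f))(x)
>>> assert hess.shape == (8, 5, 5)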

test/test_pythonkey.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def f(x):
         new_inp = torch.randn(3)
         self.assertEqual(fx_f(new_inp), f(new_inp))
 
-    def test_make_fx_jvp(self, device):
+    def test_make_fx_vjp(self, device):
         def f(x):
             return torch.sin(x).sum()
 
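The renamed test exercises tracing through ``make_fx``. A rough sketch of the pattern it follows (illustrative only; it assumes ``make_fx`` keeps the trace-then-call usage shown in the surrounding context lines, and the variable names mirror those context lines rather than the renamed test body):

>>> import torch
>>> from functorch import make_fx
>>> def f(x):
>>>   return torch.sin(x).sum()
>>>
>>> inp = torch.randn(3)
>>> fx_f = make_fx(f)(inp)  # trace f into an FX graph with example inputs
>>> new_inp = torch.randn(3)
>>> assert torch.allclose(fx_f(new_inp), f(new_inp))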