@@ -139,29 +139,29 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, creat


# How do we increment and decrement the nesting? I don't think we can.
-def vjp(f: Callable, *primals, has_aux=False):
+def vjp(func: Callable, *primals, has_aux=False):
    """
    Standing for the vector-Jacobian product, returns a tuple containing the
-    results of :attr:`f` applied to :attr:`primals` and a function that, when
-    given ``cotangents``, computes the reverse-mode Jacobian of :attr:`f` with
+    results of :attr:`func` applied to :attr:`primals` and a function that, when
+    given ``cotangents``, computes the reverse-mode Jacobian of :attr:`func` with
    respect to :attr:`primals` times ``cotangents``.

    Args:
-        f (Callable): A Python function that takes one or more arguments. Must
+        func (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
-        primals (Tensors): Positional arguments to :attr:`f` that must all be
+        primals (Tensors): Positional arguments to :attr:`func` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
-        has_aux (bool): Flag indicating that :attr:`f` returns a
+        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
-        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`f`
+        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`func`
        applied to :attr:`primals` and a function that computes the vjp of
-        :attr:`f` with respect to all :attr:`primals` using the cotangents passed
+        :attr:`func` with respect to all :attr:`primals` using the cotangents passed
        to the returned function. If ``has_aux is True``, then instead returns a
        ``(output, vjp_fn, aux)`` tuple.

    The returned ``vjp_fn`` function will return a tuple of each VJP.
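A minimal usage sketch of the ``vjp`` API described in this docstring, assuming ``functorch`` is importable:

import torch
from functorch import vjp

x = torch.randn(3)
# vjp returns the function output and a closure over the backward pass.
output, vjp_fn = vjp(torch.sin, x)
# The closure takes cotangents shaped like the output and returns one VJP per primal.
(grad_x,) = vjp_fn(torch.ones_like(output))
assert torch.allclose(grad_x, x.cos())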
@@ -240,7 +240,7 @@ def vjp(f: Callable, *primals, has_aux=False):
        with torch.enable_grad():
            primals = _wrap_all_tensors(primals, level)
            diff_primals = _create_differentiable(primals, level)
-            primals_out = f(*diff_primals)
+            primals_out = func(*diff_primals)

            if has_aux:
                primals_out, aux = primals_out
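When ``has_aux=True``, the auxiliary value is passed through without being differentiated; a sketch, assuming the renamed ``func`` signature above:

import torch
from functorch import vjp

def f(x):
    y = x.sin()
    return y, y.detach()  # (output, aux): aux is carried along, not differentiated

x = torch.randn(3)
output, vjp_fn, aux = vjp(f, x, has_aux=True)
(grad_x,) = vjp_fn(torch.ones_like(output))
assert torch.allclose(grad_x, x.cos())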
@@ -286,9 +286,9 @@ def _safe_zero_index(x):
    return x[0]


-def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
+def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
    """
-    Computes the Jacobian of :attr:`f` with respect to the arg(s) at index
+    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
    :attr:`argnum` using reverse mode autodiff

    Args:
@@ -297,18 +297,18 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
-        has_aux (bool): Flag indicating that :attr:`f` returns a
+        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
-        Returns a function that takes in the same inputs as :attr:`f` and
-        returns the Jacobian of :attr:`f` with respect to the arg(s) at
+        Returns a function that takes in the same inputs as :attr:`func` and
+        returns the Jacobian of :attr:`func` with respect to the arg(s) at
        :attr:`argnums`. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
-        is the Jacobian and ``aux`` is auxiliary objects returned by ``f``.
+        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian
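The pointwise example that sentence refers to falls outside this hunk; a sketch of that usage, assuming ``functorch`` is importable:

import torch
from functorch import jacrev

x = torch.randn(5)
# For the elementwise sin, the Jacobian is diag(cos(x)).
jacobian = jacrev(torch.sin)(x)
assert torch.allclose(jacobian, torch.diag(x.cos()))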
@@ -322,7 +322,7 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

-        >>> from functorch import jacrev
+        >>> from functorch import jacrev, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacrev(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)
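The :attr:`argnums` argument documented above selects which positional input the Jacobian is taken with respect to; a sketch with a hypothetical ``f``, assuming ``functorch`` is importable:

import torch
from functorch import jacrev

def f(x, y):
    return x * y

x, y = torch.randn(5), torch.randn(5)
# Differentiate with respect to the second argument: d(x*y)/dy = diag(x).
jac_y = jacrev(f, argnums=1)(x, y)
assert torch.allclose(jac_y, torch.diag(x))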
@@ -385,9 +385,9 @@ def jacrev(f: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
    outer one. This is because ``jacrev`` is a "function transform": its result
    should not depend on the result of a context manager outside of ``f``.
    """
-    @wraps(f)
+    @wraps(func)
    def wrapper_fn(*args):
-        f_wrapper, primals = _argnums_partial(f, args, argnums)
+        f_wrapper, primals = _argnums_partial(func, args, argnums)
        vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
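The wrapper above builds the Jacobian out of ``vjp``: conceptually it evaluates the VJP against each standard-basis cotangent (the real code vmaps over the basis). A hand-rolled sketch of that idea:

import torch
from functorch import vjp

x = torch.randn(4)
out, vjp_fn = vjp(torch.sin, x)
# One VJP per output coordinate yields one Jacobian row per basis vector.
rows = [vjp_fn(e)[0] for e in torch.eye(out.numel())]
jacobian = torch.stack(rows)
assert torch.allclose(jacobian, torch.diag(x.cos()))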
@@ -654,7 +654,52 @@ def safe_unpack_dual(dual, strict):
    return primal, tangent


-def jvp(f, primals, tangents, *, strict=False):
+def jvp(func, primals, tangents, *, strict=False):
+    """
+    Standing for the Jacobian-vector product, returns a tuple containing
+    the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
+    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        primals (Tensors): Positional arguments to :attr:`func` that must all be
+            Tensors. The returned function will also be computing the
+            derivative with respect to these arguments
+        tangents (Tensors): The "vector" for which Jacobian-vector-product is
+            computed. Must be the same structure and sizes as the inputs to
+            ``func``.
+
+    Returns:
+        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
+        evaluated at ``primals`` and the Jacobian-vector product.
+
+    .. warning::
+        PyTorch's forward-mode AD coverage on operators is not very good at the
+        moment. You may see this API error out with "forward-mode AD not
+        implemented for operator X". If so, please file us a bug report and we
+        will prioritize it.
+
+    jvp is useful when you wish to compute gradients of a function R^1 -> R^N
+
+        >>> from functorch import jvp
+        >>> x = torch.randn([])
+        >>> f = lambda x: x * torch.tensor([1., 2., 3])
+        >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
+        >>> assert torch.allclose(value, f(x))
+        >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
+
+    :func:`jvp` can support functions with multiple inputs by passing in the
+    tangents for each of the inputs
+
+        >>> from functorch import jvp
+        >>> x = torch.randn(5)
+        >>> y = torch.randn(5)
+        >>> f = lambda x, y: (x * y)
+        >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
+        >>> assert torch.allclose(output, x + y)
+
+    """
    if not isinstance(primals, tuple):
        raise RuntimeError(
            f'{jvp_str}: Expected primals to be a tuple. '
@@ -683,7 +728,7 @@ def jvp(f, primals, tangents, *, strict=False):
            flat_duals = tuple(fwAD.make_dual(p, t)
                               for p, t in zip(flat_primals, flat_tangents))
            duals = tree_unflatten(flat_duals, primals_spec)
-            result_duals = f(*duals)
+            result_duals = func(*duals)
            assert_output_is_tensor_or_tensors(result_duals, jvp_str)
            result_duals, spec = tree_flatten(result_duals)
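The hunk above leans on PyTorch's public forward-mode AD primitives; a standalone sketch of the same dual-number mechanics using ``torch.autograd.forward_ad`` directly:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
t = torch.ones(3)
with fwAD.dual_level():
    # Pack primal and tangent into a dual tensor, run the op, then unpack.
    dual = fwAD.make_dual(x, t)
    out = torch.sin(dual)
    primal, tangent = fwAD.unpack_dual(out)
assert torch.allclose(tangent, torch.cos(x) * t)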
@@ -706,9 +751,87 @@ def safe_unflatten(tensor, dim, shape):
    return tensor.unflatten(dim, shape)


-def jacfwd(f, argnums=0):
+def jacfwd(func, argnums=0):
+    """
+    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
+    :attr:`argnums` using forward-mode autodiff
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
+            saying which arguments to get the Jacobian with respect to.
+            Default: 0.
+
+    Returns:
+        Returns a function that takes in the same inputs as :attr:`func` and
+        returns the Jacobian of :attr:`func` with respect to the arg(s) at
+        :attr:`argnums`.
+
+    .. warning::
+        PyTorch's forward-mode AD coverage on operators is not very good at the
+        moment. You may see this API error out with "forward-mode AD not
+        implemented for operator X". If so, please file us a bug report and we
+        will prioritize it.
+
+    A basic usage with a pointwise, unary operation will give a diagonal array
+    as the Jacobian
+
+        >>> from functorch import jacfwd
+        >>> x = torch.randn(5)
+        >>> jacobian = jacfwd(torch.sin)(x)
+        >>> expected = torch.diag(torch.cos(x))
+        >>> assert torch.allclose(jacobian, expected)
+
+    :func:`jacfwd` can be composed with vmap to produce batched
+    Jacobians:
+
+        >>> from functorch import jacfwd, vmap
+        >>> x = torch.randn(64, 5)
+        >>> jacobian = vmap(jacfwd(torch.sin))(x)
+        >>> assert jacobian.shape == (64, 5, 5)
+
+    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
+    to produce Hessians
+
+        >>> from functorch import jacfwd, jacrev
+        >>> def f(x):
+        >>>     return x.sin().sum()
+        >>>
+        >>> x = torch.randn(5)
+        >>> hessian = jacfwd(jacrev(f))(x)
+        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
+
+    By default, :func:`jacfwd` computes the Jacobian with respect to the first
+    input. However, it can compute the Jacobian with respect to a different
+    argument by using :attr:`argnums`:
+
+        >>> from functorch import jacfwd
+        >>> def f(x, y):
+        >>>     return x + y ** 2
+        >>>
+        >>> x, y = torch.randn(5), torch.randn(5)
+        >>> jacobian = jacfwd(f, argnums=1)(x, y)
+        >>> expected = torch.diag(2 * y)
+        >>> assert torch.allclose(jacobian, expected)
+
+    Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
+    with respect to multiple arguments
+
+        >>> from functorch import jacfwd
+        >>> def f(x, y):
+        >>>     return x + y ** 2
+        >>>
+        >>> x, y = torch.randn(5), torch.randn(5)
+        >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
+        >>> expectedX = torch.diag(torch.ones_like(x))
+        >>> expectedY = torch.diag(2 * y)
+        >>> assert torch.allclose(jacobian[0], expectedX)
+        >>> assert torch.allclose(jacobian[1], expectedY)
+
+    """
    def wrapper_fn(*args):
-        f_wrapper, primals = _argnums_partial(f, args, argnums)
+        f_wrapper, primals = _argnums_partial(func, args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
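``jacfwd`` pushes one tangent (standard-basis vector) per input element through the function, while ``jacrev`` pulls one cotangent per output element back, so the two agree and the cheaper choice depends on the input/output sizes. A sketch of that equivalence with a hypothetical ``f``, assuming ``functorch`` is importable:

import torch
from functorch import jacfwd, jacrev

W = torch.randn(30, 3)

def f(x):
    return W @ x.sin()  # maps R^3 -> R^30

x = torch.randn(3)
# Same (30, 3) Jacobian either way; conceptually forward mode needs 3 JVPs,
# reverse mode needs 30 VJPs.
assert torch.allclose(jacfwd(f)(x), jacrev(f)(x))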
@@ -738,8 +861,46 @@ def push_jvp(basis):
    return wrapper_fn


-def hessian(f, argnums=0):
-    return jacfwd(jacrev(f, argnums), argnums)
+def hessian(func, argnums=0):
+    """
+    Computes the Hessian of :attr:`func` with respect to the arg(s) at index
+    :attr:`argnums` via a forward-over-reverse strategy.
+
+    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
+    a good default for performance. It is possible to compute Hessians
+    through other compositions of :func:`jacfwd` and :func:`jacrev` like
+    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.
+
+    Args:
+        func (function): A Python function that takes one or more arguments,
+            one of which must be a Tensor, and returns one or more Tensors
+        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
+            saying which arguments to get the Hessian with respect to.
+            Default: 0.
+
+    Returns:
+        Returns a function that takes in the same inputs as :attr:`func` and
+        returns the Hessian of :attr:`func` with respect to the arg(s) at
+        :attr:`argnums`.
+
+    .. warning::
+        PyTorch's forward-mode AD coverage on operators is not very good at the
+        moment. You may see this API error out with "forward-mode AD not
+        implemented for operator X". If so, please file us a bug report and we
+        will prioritize it.
+
+    A basic usage with an R^N -> R^1 function gives an N x N Hessian:
+
+        >>> from functorch import hessian
+        >>> def f(x):
+        >>>     return x.sin().sum()
+        >>>
+        >>> x = torch.randn(5)
+        >>> hess = hessian(f)(x)
+        >>> assert torch.allclose(hess, torch.diag(-x.sin()))
+
+    """
+    return jacfwd(jacrev(func, argnums), argnums)


def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
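A short sketch showing that the compositions mentioned in the ``hessian`` docstring agree, assuming ``functorch`` is importable:

import torch
from functorch import hessian, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(4)
# Forward-over-reverse (what hessian uses) and reverse-over-reverse match.
h_for_over_rev = hessian(f)(x)
h_rev_over_rev = jacrev(jacrev(f))(x)
assert torch.allclose(h_for_over_rev, h_rev_over_rev)
assert torch.allclose(h_for_over_rev, torch.diag(-x.sin()))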