Commit a3137b6

Beef up transform limitations doc (#879)
I want to be able to point someone at this page whenever we get asked about the limitations of vmap. Please let me know if there are things we're still missing from here.
1 parent d768509 commit a3137b6

File tree

1 file changed: +65 -0 lines changed

docs/source/ux_limitations.rst

Lines changed: 65 additions & 0 deletions
@@ -53,6 +53,23 @@ Please rewrite ``f`` to return ``intermediate``:
    grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of functorch's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform
over that call. If it cannot, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is
implemented and is the reason why we designed the functorch library. Please use
the functorch equivalents of the ``torch.autograd`` APIs instead:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``functorch.vjp`` or ``functorch.grad``
- ``torch.autograd.functional.jvp`` -> ``functorch.jvp``
- ``torch.autograd.functional.jacobian`` -> ``functorch.jacrev`` or ``functorch.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``functorch.hessian``
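
For example, if you're computing per-sample gradients by calling
``torch.autograd.grad`` inside a function passed to :func:`vmap`, compose
``functorch.grad`` with :func:`vmap` instead. A minimal sketch (the
``compute_loss`` function here is a made-up example):

::

    import torch
    from functorch import grad, vmap

    def compute_loss(weight, sample):
        # Hypothetical per-sample loss; any scalar-valued function works.
        return (sample * weight).sin().sum()

    weight = torch.randn(3)
    samples = torch.randn(8, 3)

    # grad(compute_loss) differentiates with respect to the first argument;
    # vmap then maps that gradient computation over the batch of samples.
    per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0))(weight, samples)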

vmap limitations
----------------

@@ -144,6 +161,14 @@ elements (or more):
    vmap(f, in_dims=(0, 0))(x, y)
    assert torch.allclose(x, expected)

Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters one in your code.

This is not a fundamental limitation; we could theoretically support it in the
future, but we have chosen not to for now.
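
To illustrate the kind of rewrite this implies, here is a minimal sketch (the
function names are made up); allocating and returning a fresh Tensor sidesteps
the ``out=`` limitation:

::

    import torch
    from functorch import vmap

    def f_with_out(x):
        result = torch.empty_like(x)
        torch.sin(x, out=result)  # vmap errors on the out= variant
        return result

    def f_rewritten(x):
        return torch.sin(x)  # allocate-and-return works under vmap

    x = torch.randn(4, 3)
    y = vmap(f_rewritten)(x)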

Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
@@ -180,6 +205,46 @@ using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loo
We're investigating adding equivalents of those to functorch
(open an issue on `GitHub <https://github.com/pytorch/functorch>`_ to voice your support!).


Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

    def f(x):
        return x.item()

    x = torch.randn(3)
    vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.
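
As a rough sketch of such a rewrite (the function here is made up): keep the
result as a 0-dim Tensor instead of converting it to a Python number.

::

    import torch
    from functorch import vmap

    def f_rewritten(x):
        # Instead of `return x.sum().item()`, return the 0-dim Tensor;
        # vmap can batch over Tensors but not over Python scalars.
        return x.sum()

    x = torch.randn(3, 5)
    vmap(f_rewritten)(x)  # returns a Tensor of shape (3,)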

You may also encounter an error message about ``.item()`` even though you never
called it directly. In those cases, it is possible that PyTorch is calling
``.item()`` internally -- please file an issue on GitHub and we'll fix
PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``
and ``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape (2, 1),
but ``torch.nonzero(xs[1])`` returns a Tensor of shape (1, 1).
We are unable to construct a single output Tensor;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).
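
If you need nonzero-like information under :func:`vmap`, one possible
workaround (our suggestion here, not a functorch API) is to compute a quantity
whose shape does not depend on the data, such as a boolean mask:

::

    import torch
    from functorch import vmap

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])

    # Each per-example output has the same shape as its input,
    # so vmap can stack the results into a single Tensor.
    masks = vmap(lambda x: x != 0)(xs)  # shape (2, 3), dtype torch.bool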

Randomness
----------
The user's intention when calling a random operation can be unclear. Specifically, some users may want
