@@ -53,6 +53,23 @@ Please rewrite ``f`` to return ``intermediate``:
grad_x, intermediate = grad(f, has_aux=True)(x)
+ torch.autograd APIs
+ -------------------
+
+ If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
+ or ``torch.autograd.backward`` inside of a function being transformed by
+ :func:`vmap` or one of functorch's AD transforms (:func:`vjp`, :func:`jvp`,
+ :func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
+ If it is unable to do so, you'll receive an error message.
+
+ This is a fundamental design limitation in how PyTorch's AD support is implemented
+ and is the reason why we designed the functorch library. Please use the functorch
+ equivalents of the ``torch.autograd`` APIs instead (see the sketch after this list):
+
+ - ``torch.autograd.grad``, ``Tensor.backward`` -> ``functorch.vjp`` or ``functorch.grad``
+ - ``torch.autograd.functional.jvp`` -> ``functorch.jvp``
+ - ``torch.autograd.functional.jacobian`` -> ``functorch.jacrev`` or ``functorch.jacfwd``
+ - ``torch.autograd.functional.hessian`` -> ``functorch.hessian``
+
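+ As a minimal sketch of such a rewrite (the function ``f`` and the inputs ``x``
+ and ``xs`` below are hypothetical), replacing ``torch.autograd.grad`` with
+ ``functorch.grad`` also lets the gradient computation compose with :func:`vmap`:
+
+ ::
+
+   import torch
+   from functorch import grad, vmap
+
+   def f(x):
+       # a scalar-valued function of a 1-D tensor
+       return (x ** 2).sum()
+
+   x = torch.randn(3)
+
+   # Instead of torch.autograd.grad(f(x), x) (which requires x.requires_grad_()),
+   # compute the gradient functionally:
+   grad_x = grad(f)(x)
+
+   # The functional version composes with vmap for per-example gradients:
+   xs = torch.randn(5, 3)
+   per_sample_grads = vmap(grad(f))(xs)
+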
vmap limitations
----------------
@@ -144,6 +161,14 @@ elements (or more):
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)
+ Mutation: out= PyTorch Operations
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ :func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
+ It will error out gracefully if it encounters that in your code.
+
+ This is not a fundamental limitation; we could theoretically support this in the
+ future but we have chosen not to for now.
+
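+ For example, here is a hypothetical sketch of the problem and of a rewrite that
+ avoids ``out=`` (the functions ``f`` and ``f_fixed`` are made up for illustration):
+
+ ::
+
+   import torch
+   from functorch import vmap
+
+   def f(x):
+       out = torch.empty_like(x)
+       torch.sin(x, out=out)  # vmap errors on the out= variant
+       return out
+
+   def f_fixed(x):
+       return torch.sin(x)  # the out-of-place variant works under vmap
+
+   x = torch.randn(4, 3)
+   vmap(f_fixed)(x)  # works; vmap(f)(x) would raise an error
+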
Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
@@ -180,6 +205,46 @@ using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loo
We're investigating adding equivalents of those to functorch
(open an issue on `GitHub <https://github.com/pytorch/functorch>`_ to voice your support!).
+ Data-dependent operations (.item())
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ We do not (and will not) support vmap over a user-defined function that calls
+ ``.item()`` on a Tensor. For example, the following will raise an error message:
+
+ ::
+
+   def f(x):
+       return x.item()
+
+   x = torch.randn(3)
+   vmap(f)(x)
+
+ Please try to rewrite your code to not use ``.item()`` calls.
+
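+ As a hypothetical sketch of such a rewrite, keep the value as a Tensor rather
+ than converting it to a Python number (``f`` and ``f_fixed`` are made-up examples):
+
+ ::
+
+   import torch
+   from functorch import vmap
+
+   def f(x):
+       return x.sum().item()  # errors under vmap
+
+   def f_fixed(x):
+       return x.sum()  # keep the result as a Tensor
+
+   x = torch.randn(3, 4)
+   vmap(f_fixed)(x)
+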
+ You may also encounter an error message about using ``.item()`` but you might
+ not have used it. In those cases, it is possible that PyTorch internally is
+ calling ``.item()`` -- please file an issue on GitHub and we'll fix
+ PyTorch internals.
+
+ Dynamic shape operations (nonzero and friends)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ ``vmap(f)`` requires that ``f`` applied to every "example" in your input
+ returns a Tensor with the same shape. Operations such as ``torch.nonzero`` and
+ ``torch.is_nonzero`` are not supported and will error as a result.
+
+ To see why, consider the following example:
+
+ ::
+
+   xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
+   vmap(torch.nonzero)(xs)
+
+ ``torch.nonzero(xs[0])`` returns a Tensor of shape (2, 1),
+ but ``torch.nonzero(xs[1])`` returns a Tensor of shape (1, 1).
+ We are unable to construct a single Tensor as an output;
+ the output would need to be a ragged Tensor (and PyTorch does not yet have
+ the concept of a ragged Tensor).
+
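+ If a fixed-shape alternative works for your use case, you may be able to avoid
+ the dynamic shape entirely. As a sketch (reusing the hypothetical batch ``xs``
+ from above), a boolean mask of the nonzero positions has the same shape as each
+ example and therefore works under vmap:
+
+ ::
+
+   import torch
+   from functorch import vmap
+
+   xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
+   # A mask has a fixed shape (same as each example), so vmap can stack the results.
+   mask = vmap(lambda x: x != 0)(xs)  # shape (2, 3), dtype torch.bool
+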
Randomness
----------
The user's intention when calling a random operation can be unclear. Specifically, some users may want