@@ -171,25 +171,16 @@ def _get_name(func: Callable):
171
171
# on BatchedTensors perform the batched operations that the user is asking for.
172
172
def vmap (func : Callable , in_dims : in_dims_t = 0 , out_dims : out_dims_t = 0 ) -> Callable :
173
173
"""
174
- vmap is the vectorizing map. Returns a new function that maps `func` over some
175
- dimension of the inputs. Semantically, vmap pushes the map into PyTorch
176
- operations called by `func`, effectively vectorizing those operations.
174
+ vmap is the vectorizing map; `vmap(func)` returns a new function that maps
175
+ `func` over some dimension of the inputs. Semantically, vmap pushes the map
176
+ into PyTorch operations called by `func`, effectively vectorizing those
177
+ operations.
177
178
178
179
vmap is useful for handling batch dimensions: one can write a function `func`
179
180
that runs on examples and then lift it to a function that can take batches of
180
181
examples with `vmap(func)`. vmap can also be used to compute batched
181
182
gradients when composed with autograd.
182
183
183
- .. warning::
184
- functorch.vmap is an experimental prototype that is subject to
185
- change and/or deletion. Please use at your own risk.
186
-
187
- .. note::
188
- If you're interested in using vmap for your use case, please
189
- `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
190
- We're interested in gathering feedback from early adopters to inform
191
- the design.
192
-
193
184
Args:
194
185
func (function): A Python function that takes one or more arguments.
195
186
Must return one or more Tensors.
0 commit comments