Commit bcee114

Merge branch 'develop' into uarray-intro

2 parents 701676b + 165b46c

File tree

307 files changed: +23046 / -45692 lines changed


.github/dependabot.yml

Lines changed: 19 additions & 0 deletions

@@ -9,3 +9,22 @@ updates:
       prefix-development: chore
       include: scope
     open-pull-requests-limit: 5
+    groups:
+      # All Nx updates must occur together. See:
+      # https://nx.dev/recipes/tips-n-tricks/keep-nx-versions-in-sync
+      nx:
+        patterns:
+          - "nx"
+          - "@nrwl/*"
+          - "@nx/*"
+        update-types:
+          - "patch"
+          - "minor"
+    ignore:
+      # Do major Nx updates manually with `nx migrate`, not with Dependabot
+      - dependency-name: "nx"
+        update-types: ["version-update:semver-major"]
+      - dependency-name: "@nrwl/*"
+        update-types: ["version-update:semver-major"]
+      - dependency-name: "@nx/*"
+        update-types: ["version-update:semver-major"]

.github/pull_request_template.md

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ Thank you for submitting a blog post! We want to make sure content produced for

 ## Non-text contents

+- [ ] Blog post featured image is in PNG or JPEG format, **not** SVG.
 - [ ] All content is represented as text (for example, images need alt text and videos need captions or descriptive transcripts).
 - [ ] If there are emojis, there are not more than three in a row.
 - [ ] Don't use [flashing gifs or videos](https://www.w3.org/TR/UNDERSTANDING-WCAG20/seizure-does-not-violate.html).

.gitignore

Lines changed: 6 additions & 0 deletions

@@ -44,3 +44,9 @@ Thumbs.db
 .eslintcache
 .env
 .history
+
+#rss.xml file
+rss.xml
+
+.next/
+.nx/

README.md

Lines changed: 169 additions & 141 deletions
Large diffs are not rendered by default.

New file — Lines changed: 327 additions & 0 deletions

@@ -0,0 +1,327 @@
---
title: 'Review: `torch.func` Contribution'
published: September 22, 2023
author: kshiteej-kalambarkar
description: '`torch.func` (previously known as `functorch`) is a PyTorch module
designed to offer JAX-like transforms. Within this module, various higher-order
functions, such as `grad`, `vmap`, and `vjp`, are made accessible. These
transforms help users to easily compute gradients for the parameters of their model
or write batch-size-agnostic code. The beauty of these transformations lies in
their ability to compose with one another. Thanks to this, the process of
calculating per-sample gradients becomes the straightforward application of
`vmap(grad(model))`.'
---

<base target="_blank" />

`torch.func` (previously known as `functorch`) is a PyTorch module designed to
offer JAX-like transforms. Within this module, various higher-order
functions, such as `grad`, `vmap`, and `vjp`, are made accessible. These
transforms help users to easily compute gradients for the parameters of their model
or write batch-size-agnostic code. The beauty of these transformations lies in
their ability to compose with one another. Thanks to this, the process of
calculating per-sample gradients becomes the straightforward application of
`vmap(grad(model))`.
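
To make the per-sample-gradient composition concrete, here is a minimal sketch; the toy model, the `loss_fn` name, and the shapes are ours, purely for illustration:

```python
import torch
from torch.func import grad, vmap

# Toy per-sample loss for a single example (illustrative only).
def loss_fn(weight, x, y):
    return ((x @ weight - y) ** 2).mean()

weight = torch.randn(3)
xs = torch.randn(8, 3)  # batch of 8 samples
ys = torch.randn(8)

# `grad` differentiates w.r.t. the first argument (the weight);
# `vmap` maps over the sample dimension of `xs` and `ys`.
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weight, xs, ys)
print(per_sample_grads.shape)  # torch.Size([8, 3])
```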

Here are a few of the tasks we had the opportunity to tackle:

## Adding Batching Rules for `vmap`

`vmap` is a transformation that accepts a function that operates on non-batched
tensors and returns a new function that operates on batched tensors.
When processing a batched input, an additional
dimension, denoted by `in_dims`, is introduced to indicate which dimension to
apply the function over. Conceptually, it emulates a `for` loop that iterates
through all the data points and stacks the results. Importantly, it performs this
operation efficiently by pushing the `for` loop into the PyTorch operations,
allowing the batches to run in parallel.

Consider the following example:

```python
import torch

# Written to handle only a single sample.
def my_simple_model(feature_vec, weight):
    return torch.dot(feature_vec, weight).relu()

batch_size = 4
batched_inputs = torch.randn(batch_size, 3)
weight = torch.randn(3)

# For-loop version
expected = []
for input in batched_inputs:
    expected.append(my_simple_model(input, weight))
expected = torch.stack(expected)

# vmap
# `in_dims` specifies the dimension that should be mapped over.
# In this case, we map only over dim 0 of `batched_inputs`.
actual = torch.vmap(my_simple_model, in_dims=(0, None))(batched_inputs, weight)

# Verify that the results match.
torch.testing.assert_close(expected, actual)
```

To support `vmap` for PyTorch operators, we need to specify the batching rule,
i.e. how to map the given function over a batched input.
A batching rule is essentially a function which takes one or multiple batched
inputs and computes the batched operation. In the above example, to support
`vmap` for `my_simple_model`, we need to know the batching rules for `torch.dot`
and `torch.relu` to be able to vectorize our model.
PyTorch has more than [2000 operators](https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-operators-come-from-more-than-you-wanted-to-know/373)
and we need to have coverage for all of them
to support `vmap`. That being said, there is a `for`-loop
fallback in case an operator is not supported, so as not to crash the code.

From the point of view of adding batching rules, PyTorch operators can be roughly
categorized as primitive or composite.
Primitive operators are the ones for which we specify the batching and
gradient rules. Composite operators are implemented using these primitive operators
and other simpler composite operators. If we implement batching rules for every
primitive operator, we automatically get the batching rules for composite operators.

There are two ways to add batching support for an operator:

- Manually write the batching rule. See for example the [batching rule for torch.dot](https://github.com/pytorch/pytorch/blob/b30ee35a6f141d3247a24fd09f96ea50a7e2b3c7/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp#L25-L34)
- Decompose the operator into other operators for which we already have a
  batching rule (a conceptual sketch follows this list). See for example the [batching rule for torch.vdot](https://github.com/pytorch/pytorch/blob/b30ee35a6f141d3247a24fd09f96ea50a7e2b3c7/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp#L35-L37)
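
To illustrate the idea at the Python level, here is a conceptual sketch of what a batching rule for `torch.dot` boils down to; the real rule linked above lives in C++, and the helper name here is made up:

```python
import torch

# Conceptual batching rule for `torch.dot`: a batch of dot products is just an
# elementwise multiply followed by a sum over the last dimension.
def dot_batch_rule(batched_x, batched_y):
    return (batched_x * batched_y).sum(-1)

x = torch.randn(4, 3)
y = torch.randn(4, 3)

# Reference: apply `torch.dot` sample by sample and stack the results.
expected = torch.stack([torch.dot(a, b) for a, b in zip(x, y)])
torch.testing.assert_close(dot_batch_rule(x, y), expected)
```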

## Composite Compliance

As mentioned earlier, we obtain batching rules effortlessly for composite operators,
but this holds true only under certain constraints. These constraints include
refraining from accessing the tensor's data pointer and avoiding the use of
`out=` variants of the operators.
For the full list of constraints, see [this documentation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#composite-compliance).
When all of these hold for an operator, we say that it is composite compliant.

Unfortunately, operators that claim to be composite may occasionally deviate
from these constraints. While such deviations may not pose issues when utilizing
plain eager PyTorch, they can lead to complications when using `torch.func`
transformations.
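
As a hypothetical illustration (this function is ours, not from PyTorch), an operator implemented via an `out=` variant works fine in eager mode yet violates the compliance constraints that the transforms depend on:

```python
import torch

# A non-compliant "composite" implementation: it pre-allocates a buffer and
# writes into it with an `out=` variant, which is disallowed for composite ops.
def my_composite_sin(x):
    result = torch.empty_like(x)
    torch.sin(x, out=result)
    return result

# Plain eager PyTorch is perfectly happy with this.
print(my_composite_sin(torch.randn(3)))

# Under a transform such as `vmap`, the pre-allocated `out` buffer knows
# nothing about the batch dimension, so this pattern generally cannot be
# vectorized the way a compliant decomposition would be.
# torch.vmap(my_composite_sin)(torch.randn(4, 3))
```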

### Testing for Composite Compliance

We have tests to ensure that operators tagged as composite are indeed composite
compliant. We test this by creating a new subclass, `CompositeCompliantTensor`,
that utilizes the `__torch_dispatch__` mechanism. This mechanism is invoked for
all operators during testing, enabling us to detect any non-compliant behavior
exhibited by an operator.

Our testing approach involves running tests on the actual [operator](https://github.com/pytorch/pytorch/blob/40b2c796dcae5768ff5de279b13914d5948fd414/test/test_ops.py#L1446),
as well as its [backward formula](https://github.com/pytorch/pytorch/blob/40b2c796dcae5768ff5de279b13914d5948fd414/test/test_ops.py#L1459)
and [forward AD formula](https://github.com/pytorch/pytorch/blob/40b2c796dcae5768ff5de279b13914d5948fd414/test/test_ops.py#L1476).
Testing both the backward and forward formulas is crucial because we may encounter
scenarios involving `vmap(vjp(fn))` or `vmap(jvp(fn))`.
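
To give a flavor of the mechanism, here is a simplified sketch that uses a `TorchDispatchMode` rather than the actual `CompositeCompliantTensor` subclass from the test suite: `__torch_dispatch__` sees every `aten` operation an operator performs, which is what makes it possible to flag non-compliant calls.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Simplified stand-in for the real test machinery: log every aten op that a
# supposedly composite operator dispatches. A compliance checker would inspect
# `func` here and flag things such as `out=` variants.
class LogAtenOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("dispatched:", func)
        return func(*args, **kwargs)

x = torch.randn(3)
with LogAtenOps():
    torch.square(x)  # prints the aten op(s) this call dispatches to
```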

## Support for `chunk_size` in `vmap` and `jacrev`

The computation of the Jacobian can be memory-intensive, and users have raised
concerns about related issues (e.g., [this one](https://github.com/pytorch/functorch/issues/680)).
In response to these concerns, we have introduced a feature that allows for the
calculation of `jacrev` and `vmap` in smaller, user-defined chunks, determined
by the `chunk_size` argument. This adjustment serves to reduce the peak memory
usage during the computation process. With this argument, users can specify the
number of rows of the Jacobian to be computed at once. This enhancement was
incorporated into the [jacrev PR](https://github.com/pytorch/pytorch/pull/89376)
and the [vmap PR](https://github.com/pytorch/pytorch/pull/91157).
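
A brief usage sketch (the function and sizes here are illustrative) showing the `chunk_size` keyword on both transforms:

```python
import torch
from torch.func import jacrev, vmap

def fn(x):
    return x.sin()

x = torch.randn(1000)

# Compute the 1000x1000 Jacobian 64 rows at a time instead of all at once,
# trading some speed for a lower peak memory footprint.
jacobian = jacrev(fn, chunk_size=64)(x)

# `vmap` exposes the same knob: process the batch in chunks of 64.
batched_dot = vmap(torch.dot, chunk_size=64)
out = batched_dot(torch.randn(1000, 3), torch.randn(1000, 3))
```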

## Support for `linearize` transform

The `jvp` transform is designed to calculate both `f(x)` and the Jacobian-vector
product. Consequently, even when one intends to compute the Jacobian-vector
product for fixed inputs, the `jvp` transform still redundantly evaluates `f(x)`.
To address such scenarios, the `linearize` transform comes into play. This
transform proves valuable when multiple `jvp` computations are needed for
constant inputs.

Note that, in order to implement this efficiently, `linearize` stores some
intermediate computations, which can result in higher memory requirements compared
to directly applying `jvp`. The `linearize` transform was implemented in this
[PR](https://github.com/pytorch/pytorch/pull/94173).
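
A minimal usage sketch: `linearize` evaluates the function once and returns a `jvp` function that can be reused for many tangents at the same (fixed) primal. The function and shapes here are illustrative.

```python
import torch
from torch.func import jvp, linearize

def fn(x):
    return x.sin()

x = torch.randn(3)

# fn(x) is evaluated once; `jvp_fn` reuses the stored intermediates.
output, jvp_fn = linearize(fn, x)

t1, t2 = torch.randn(3), torch.randn(3)
jvp1 = jvp_fn(t1)
jvp2 = jvp_fn(t2)

# Equivalent call that re-evaluates fn(x) every time.
_, expected = jvp(fn, (x,), (t1,))
torch.testing.assert_close(jvp1, expected)
```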

## Support of `torch.func` transforms within `torch.compile`

PyTorch 2.0 introduced a JIT compiler under `torch.compile`, similar to `jax.jit`.
This opened up the possibility of compiling the existing transforms to enhance
their performance. To understand how these transforms can be compiled, it is
essential to discuss the workings of the three layers within the compilation
stack, namely `dynamo`, `aot_autograd`, and `inductor`.

The `dynamo` and `aot_autograd` layers primarily focus on capturing the
computation graph and converting the captured operations into more basic operations.
This captured graph is then passed to `inductor`, the compiler. `inductor` then applies
various optimization passes before generating specialized code.

To gain insight into the different stages of this stack,
let us compile a simple program in debug mode.

```python
# Run this file with `TORCH_COMPILE_DEBUG=1`

import torch

def fn(x):
    return torch.sin(x) + torch.square(x)

torch.compile(fn)(torch.randn(4, 4))
```

**dynamo**: The primary responsibility of dynamo is to trace the Python program
and convert it into the FX graph format. The FX graph generated by `dynamo`
represents PyTorch operations from the public API, such as `torch.sin`.
Below, you can observe the graph captured by `dynamo`.

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: test/test_scratch.py:334, code: return torch.sin(x) + torch.square(x)
        sin = torch.sin(l_x_)
        square = torch.square(l_x_); l_x_ = None
        add = sin + square; sin = square = None
        return (add,)
```

**aot_autograd**: `aot_autograd` retraces all PyTorch operations to
produce a lower-level FX graph using `aten` functions (from the private API).
Additionally, `aot_autograd` decomposes composite operations into primitive
operations. For instance, a composite operation like `torch.square` is traced
down to `aten.pow(x, 2)`.

Moreover, `aot_autograd` also manages the creation of the backward graph when
requested. This is useful for transforms like `grad`, `vjp`, etc. Below, you can see the
graph generated by `aot_autograd` for the above program.

```python
def forward(self, arg0_1: f32[4, 4]):
    # File: test/test_scratch.py:334, code: return torch.sin(x) + torch.square(x)
    sin: f32[4, 4] = torch.ops.aten.sin.default(arg0_1)
    pow_1: f32[4, 4] = torch.ops.aten.pow.Tensor_Scalar(arg0_1, 2); arg0_1 = None
    add: f32[4, 4] = torch.ops.aten.add.Tensor(sin, pow_1); sin = pow_1 = None
    return (add,)
```

**inductor**: As discussed above, it is inductor's job to apply optimizations and
generate specialized code. In this case, it has fused `sin` and `square` to run
within the same `for`-loop. This allows the generated program to do more compute
per read/write, effectively improving memory bandwidth utilization.

```cpp
extern "C" void kernel(const float* in_ptr0, float* out_ptr0) {
    for (long i0 = 0L; i0 < 16L; i0 += 8L) {
        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + i0);
        auto tmp1 = tmp0.sin();
        auto tmp2 = tmp0 * tmp0;
        auto tmp3 = tmp1 + tmp2;
        tmp3.store(out_ptr0 + i0);
    }
}
```

### Teaching `dynamo` about `torch.func` transforms

Now that we have a basic understanding of how `torch.compile` works, let us
delve into how we extended the support for `torch.func` transforms. Given that
`aot_autograd` is already capable of tracing through the transforms, our task is
to teach `dynamo` to validate that the user-defined function intended for
transformation is free of side effects on global state and contains no graph breaks.
When the function meets these criteria, we can put
the `torch.func` transform into the FX graph and delegate the remaining
processing to the lower layers of the stack.

However, if the function cannot be successfully traced because it fails to
meet the above constraints, we fall back to the eager implementation, and this
particular portion of the code remains uncompiled.

Let us have a look at what `dynamo` and `aot_autograd` generate when we compile a
program with `grad`.

```python
# Run this file with `TORCH_COMPILE_DEBUG=1`

import torch

def user_fn(x):
    return torch.sin(x)

def wrapper_fn(x):
    return torch.func.grad(user_fn)(x)

torch.compile(wrapper_fn)(torch.randn(()))
```

The output from `dynamo` is presented below. The initial `GraphModule` pertains
to the `wrapper_fn`, clearly indicating a call to `grad` on the traced representation
of the user's function intended for transformation. Subsequently, the second
`GraphModule` corresponds to the function provided by the user. In this instance,
our function didn't have side effects or graph breaks. Thus, we were able to
successfully trace through this program in one graph.

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: torch/_functorch/apis.py:363, code:
        # return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
        grad_body_0 = self.grad_body_0
        grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
        call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
        contiguous = call.contiguous(); call = None
        return (contiguous,)

class GraphModule(torch.nn.Module):
    def forward(self, l_x_):
        # No stacktrace found for following nodes
        _set_grad_enabled = torch._C._set_grad_enabled(True)

        # File: test/test_scratch.py:382, code: return torch.sin(x)
        sin = torch.sin(l_x_); l_x_ = None

        # No stacktrace found for following nodes
        _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
        return sin
```

The graph shown above is handed over to `aot_autograd` for the subsequent phase
of the compilation process. `aot_autograd` performs a trace through the
transformation, resulting in the generation of the transformed graph. This
explains why we observe a call to `cos` instead of `sin`: since we applied the
`grad` transform, `aot_autograd` traced through both the forward and backward
graphs and then optimized away the forward computation, as `grad` discards that value.

```python
def forward(self, arg0_1: f32[]):
    # File: torch/_functorch/apis.py:363,
    # code: return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
    full: f32[] = torch.ops.aten.full.default([], 1, dtype = torch.float32,
                                              layout = torch.strided,
                                              device = device(type='cpu'),
                                              pin_memory = False)
    cos: f32[] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
    mul: f32[] = torch.ops.aten.mul.Tensor(full, cos); full = cos = None
    return (mul,)
```

The inclusion of `torch.func` support within `torch.compile` is currently under
active development. At present, our support extends
to the compilation of `grad` and `vmap`. However, it is important to note that
there are certain [limitations](https://pytorch.org/docs/main/torch.compiler_faq.html#limitations)
that restrict the range of cases we can compile.
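
For instance, compiling a `vmap`-transformed function is expected to work for simple cases like the sketch below, subject to the limitations linked above; this example is ours, not from the post:

```python
import torch

def fn(x):
    return torch.sin(x) + torch.square(x)

# Compile the vmapped function; for supported cases, dynamo traces through the
# transform and hands the resulting graph to the rest of the stack.
compiled = torch.compile(torch.vmap(fn))
out = compiled(torch.randn(8, 4))
```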

Looking ahead, our roadmap aims to extend support to all transforms with
minimal limitations, providing more comprehensive compilation support for
`torch.func` transforms.

## Closing Remarks

This project was yet another instance of the tight collaboration between Quansight
and Meta within PyTorch. In particular, we would like to thank Richard Zou and
Horace He, the `torch.func` creators, for all the design discussions and
guidance throughout these years.
