---
title: 'Review: `torch.func` Contribution'
published: September 22, 2023
author: kshiteej-kalambarkar
description: '`torch.func` (previously known as `functorch`) is a PyTorch module
designed to offer JAX-like transforms. Within this module, various higher-order
functions, such as `grad`, `vmap`, and `vjp`, are made accessible. These
transforms help users to easily compute gradients for the parameters of their model
or write batch-size agnostic code. The beauty of these transformations lies in
their ability to compose with one another. Thanks to this, the process of
calculating per-sample gradients becomes the straightforward application of
`vmap(grad(model))`.'
---

<base target="_blank" />

`torch.func` (previously known as `functorch`) is a PyTorch module designed to
offer JAX-like transforms. Within this module, various higher-order
functions, such as `grad`, `vmap`, and `vjp`, are made accessible. These
transforms help users to easily compute gradients for the parameters of their model
or write batch-size agnostic code. The beauty of these transformations lies in
their ability to compose with one another. Thanks to this, the process of
calculating per-sample gradients becomes the straightforward application of
`vmap(grad(model))`.

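To make this concrete, here is a minimal sketch of that composition. The toy
`loss_fn`, its weight, and the shapes below are our own illustration rather than
code from the project:

```python
import torch
from torch.func import grad, vmap

# Hypothetical toy setup: a linear model with a squared-error loss.
weight = torch.randn(3)

def loss_fn(weight, sample, target):
    prediction = torch.dot(sample, weight)
    return (prediction - target) ** 2

samples = torch.randn(8, 3)
targets = torch.randn(8)

# `grad` differentiates w.r.t. the first argument (the weight), while `vmap`
# maps over the batch dimension of `sample` and `target`; the weight is shared
# across the batch (`in_dims=None`).
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(weight, samples, targets)
print(per_sample_grads.shape)  # torch.Size([8, 3])
```
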
Here are a few of the tasks we had the opportunity to tackle:

## Adding Batching Rules for `vmap`

`vmap` is a transformation that accepts a function that operates on non-batched
tensors and returns a new function that operates on batched tensors. For batched
inputs, the `in_dims` argument indicates which dimension of each input the
function should be mapped over. Conceptually, `vmap` emulates a `for` loop that
iterates through all the data points and stacks the results. Importantly, it
performs this operation efficiently by pushing the `for` loop into the PyTorch
operations, allowing the batches to run in parallel.

Consider the following example:

```python
import torch

# Written to handle only a single sample.
def my_simple_model(feature_vec, weight):
    return torch.dot(feature_vec, weight).relu()

batch_size = 4
batched_inputs = torch.randn(batch_size, 3)
weight = torch.randn(3)

# For-loop version
expected = []
for input in batched_inputs:
    expected.append(my_simple_model(input, weight))
expected = torch.stack(expected)

# Vmap
# `in_dims` specifies the dimension that should be mapped over.
# In this case, we map only over 0-dim of `batched_inputs`.
actual = torch.vmap(my_simple_model, in_dims=(0, None))(batched_inputs, weight)

# Verify that the results match.
torch.testing.assert_close(expected, actual)
```

To support `vmap` for PyTorch operators, we need to specify the batching rule,
i.e., how to map the given operation over batched inputs.
A batching rule is essentially a function which takes one or more batched
inputs and computes the batched operation. In the above example, to support
`vmap` for `my_simple_model`, we need to know the batching rules for `torch.dot`
and `torch.relu` to be able to vectorize our model.
PyTorch has more than [2000 operators](https://dev-discuss.pytorch.org/t/where-do-the-2000-pytorch-operators-come-from-more-than-you-wanted-to-know/373),
and we need coverage for all of them to support `vmap`. That being said, there
is a `for`-loop fallback for operators without a batching rule, so unsupported
operators slow the code down instead of crashing it.

From the point of view of adding batching rules, PyTorch operators can be roughly
categorized as primitive or composite.
Primitive operators are the ones for which we specify the batching and
gradient rules. Composite operators are implemented using these primitive operators
and other simpler composite operators. If we implement batching rules for every
primitive operator, we automatically get the batching rules for composite operators.

There are two ways to add batching support for an operator (a conceptual sketch
of both approaches follows this list):

- Manually write the batching rule. See for example the [batching rule for torch.dot](https://github.com/pytorch/pytorch/blob/b30ee35a6f141d3247a24fd09f96ea50a7e2b3c7/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp#L25-L34)
- Decompose operators using other operators for which we already have a
  batching rule. See for example the [batching rule for torch.vdot](https://github.com/pytorch/pytorch/blob/b30ee35a6f141d3247a24fd09f96ea50a7e2b3c7/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp#L35-L37)

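The real rules linked above are written in C++ inside ATen; the Python sketch
below, with hypothetical function names, only illustrates the idea behind the
two approaches:

```python
import torch

# 1. A manually written rule: for a batched `dot` where `x` carries a batch
#    dimension at `x_bdim` and `y` is unbatched, moving the batch dimension to
#    the front reduces the computation to a matrix-vector product.
def dot_batch_rule(x, x_bdim, y):
    x = x.movedim(x_bdim, 0)  # (batch, n)
    return x @ y              # (batch,)

# 2. A decomposition: `vdot` is `dot` with the first argument conjugated, so
#    its rule can simply reuse the `dot` rule.
def vdot_batch_rule(x, x_bdim, y):
    return dot_batch_rule(x.conj(), x_bdim, y)

batched_x = torch.randn(4, 3)
y = torch.randn(3)
print(dot_batch_rule(batched_x, 0, y).shape)  # torch.Size([4])
```
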
## Composite Compliance

As mentioned earlier, we obtain batching rules effortlessly for composite operators,
but this holds true only under certain constraints. These constraints include
refraining from accessing the tensor's data pointer and avoiding the use of
`out=` variants of the operators.
For the full list of constraints, see [this documentation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#composite-compliance).
When all of these constraints hold for an operator, we say that it is composite compliant.

Unfortunately, operators that claim to be composite may occasionally deviate
from these constraints. While such deviations may not pose issues when utilizing
plain eager PyTorch, they can lead to complications when using `torch.func`
transformations.

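As a rough illustration (our own example, not taken from the codebase), the two
functions below compute the same thing, but only the first follows the
constraints; the second uses an `out=` variant, which is exactly the kind of
pattern that causes trouble under transforms such as `vmap`:

```python
import torch
from torch.func import vmap

def compliant_addrelu(x, y):
    return torch.relu(x + y)

def non_compliant_addrelu(x, y):
    buf = torch.empty_like(x)
    torch.add(x, y, out=buf)  # `out=` variants are not composite compliant
    return torch.relu(buf)

x = torch.randn(4, 3)
y = torch.randn(3)
# The compliant version composes with vmap out of the box.
print(vmap(compliant_addrelu, in_dims=(0, None))(x, y).shape)  # torch.Size([4, 3])
```
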
### Testing for Composite Compliance

We have tests to ensure that operators tagged as composite are indeed composite
compliant. We test this by creating a new tensor subclass, `CompositeCompliantTensor`,
that utilizes the `__torch_dispatch__` mechanism. This mechanism is invoked for
every operator called during the test, enabling us to detect any non-compliant
behavior exhibited by an operator.

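The snippet below is not the real `CompositeCompliantTensor`; it is a simplified
sketch that uses a dispatch mode rather than a tensor subclass to show how
`__torch_dispatch__` lets a test observe every ATen operation that an allegedly
composite operator performs:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpRecorder(TorchDispatchMode):
    """Records every ATen op seen at the dispatch level."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        self.ops.append(str(func))
        return func(*args, **kwargs)

with OpRecorder() as recorder:
    torch.square(torch.randn(3))

# A composite op such as `torch.square` shows up as the primitives it
# decomposes into (e.g. `aten.pow.Tensor_Scalar`).
print(recorder.ops)
```
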
Our testing approach involves running tests on the actual [operator](https://github.com/pytorch/pytorch/blob/40b2c796dcae5768ff5de279b13914d5948fd414/test/test_ops.py#L1446),
as well as on its [backward formula](https://github.com/pytorch/pytorch/blob/40b2c796dcae5768ff5de279b13914d5948fd414/test/test_ops.py#L1459)
and [forward AD formula](https://github.com/pytorch/pytorch/blob/40b2c796dcae5768ff5de279b13914d5948fd414/test/test_ops.py#L1476).
Testing both the backward and forward formulas is crucial because we may encounter
scenarios involving `vmap(vjp(fn))` or `vmap(jvp(fn))`.

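To see why these compositions matter, here is a small self-contained
illustration (ours, not from the test suite) of `vmap(vjp(fn))` and
`vmap(jvp(fn))` recovering the same Jacobian:

```python
import torch
from torch.func import jvp, vjp, vmap

def fn(x):
    return x.sin()

x = torch.randn(3)

# vmap over the vjp function: pull a batch of cotangents (here, basis vectors)
# through the backward formula to obtain the rows of the Jacobian.
_, vjp_fn = vjp(fn, x)
rows = vmap(vjp_fn)(torch.eye(3))[0]

# vmap over jvp: push a batch of tangents through the forward-AD formula to
# obtain the columns of the Jacobian.
cols = vmap(lambda t: jvp(fn, (x,), (t,))[1])(torch.eye(3))

torch.testing.assert_close(rows, cols.T)
```
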
## Support for `chunk_size` in `vmap` and `jacrev`

The computation of the Jacobian can be memory-intensive, and users have raised
concerns about this (e.g., [this issue](https://github.com/pytorch/functorch/issues/680)).
In response to these concerns, we have introduced a feature that allows the
calculation of `jacrev` and `vmap` to proceed in smaller, user-defined chunks,
determined by the `chunk_size` argument. This adjustment serves to reduce the
peak memory usage during the computation. With this argument, users can specify
the number of rows of the Jacobian to be computed at once. This enhancement was
incorporated into the [jacrev PR](https://github.com/pytorch/pytorch/pull/89376)
and the [vmap PR](https://github.com/pytorch/pytorch/pull/91157).

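For instance (a usage sketch with arbitrary shapes of our choosing), both APIs
accept the keyword argument:

```python
import torch
from torch.func import jacrev, vmap

def fn(x):
    return x.sin()

x = torch.randn(1000)

# Compute the 1000 x 1000 Jacobian 64 rows at a time instead of all at once,
# trading some speed for a lower peak memory footprint.
jacobian = jacrev(fn, chunk_size=64)(x)

# `vmap` exposes the same knob: the mapped dimension is processed in chunks.
batched_x = torch.randn(1000, 3)
weight = torch.randn(3)
out = vmap(torch.dot, in_dims=(0, None), chunk_size=256)(batched_x, weight)
```
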
## Support for `linearize` transform

The `jvp` transform is designed to calculate both `f(x)` and the Jacobian-vector
product. Consequently, even when one intends to compute the Jacobian-vector
product for fixed inputs, the `jvp` transform still redundantly evaluates `f(x)`.
To address such scenarios, the `linearize` transform comes into play. This
transform proves valuable when multiple `jvp` computations are needed for
constant inputs.

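A small usage sketch: `linearize` evaluates the function once and returns a
callable that computes Jacobian-vector products at that fixed point, so repeated
`jvp`s avoid re-evaluating `f`.

```python
import torch
from torch.func import linearize

def fn(x):
    return x.sin()

x = torch.randn(3)

# `output` is fn(x); `jvp_fn` computes Jacobian-vector products at x without
# re-running fn for every new tangent.
output, jvp_fn = linearize(fn, x)

for _ in range(3):
    tangent = torch.randn(3)
    print(jvp_fn(tangent))  # equivalent to jvp(fn, (x,), (tangent,))[1]
```
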
Note that, in order to implement this efficiently, `linearize` stores some
intermediate computations, which can result in higher memory requirements compared
to directly applying `jvp`. The `linearize` transform was implemented in this
[PR](https://github.com/pytorch/pytorch/pull/94173).

## Support for `torch.func` transforms within `torch.compile`

PyTorch 2.0 introduced a JIT compiler under `torch.compile`, similar to `jax.jit`.
This opened up the possibility of compiling the existing transforms to enhance
their performance. To understand how these transforms can be compiled, it is
essential to discuss the workings of the three layers within the compilation
stack, namely `dynamo`, `aot_autograd`, and `inductor`.

The `dynamo` and `aot_autograd` layers primarily focus on capturing the
computation graph and converting the captured operations into more basic operations.
This captured graph is then passed to `inductor`, the compiler. `inductor` then applies
various optimization passes before generating specialized code.

To gain insight into the different stages of this stack,
let us compile a simple program in debug mode.

```python
# Run this file with `TORCH_COMPILE_DEBUG=1`

import torch

def fn(x):
    return torch.sin(x) + torch.square(x)

torch.compile(fn)(torch.randn(4, 4))
```

**dynamo**: The primary responsibility of `dynamo` is to trace the Python program
and convert it into the FX graph format. The FX graph generated by `dynamo`
represents PyTorch operations from the public API, such as `torch.sin`.
Below, you can observe the graph captured by `dynamo`.

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: test/test_scratch.py:334, code: return torch.sin(x) + torch.square(x)
        sin = torch.sin(l_x_)
        square = torch.square(l_x_); l_x_ = None
        add = sin + square; sin = square = None
        return (add,)
```

**aot_autograd**: `aot_autograd` retraces all PyTorch operations to
produce a lower-level FX graph using `aten` functions (from the private API).
Additionally, `aot_autograd` decomposes composite operations into primitive
operations. For instance, a composite operation like `torch.square` is traced
down to `aten.pow(x, 2)`.

Moreover, `aot_autograd` also manages the creation of the backward graph when
requested. This is useful for transforms like `grad`, `vjp`, etc. Below, you can see the
graph generated by `aot_autograd` for the above program.

```python
def forward(self, arg0_1: f32[4, 4]):
    # File: test/test_scratch.py:334, code: return torch.sin(x) + torch.square(x)
    sin: f32[4, 4] = torch.ops.aten.sin.default(arg0_1)
    pow_1: f32[4, 4] = torch.ops.aten.pow.Tensor_Scalar(arg0_1, 2); arg0_1 = None
    add: f32[4, 4] = torch.ops.aten.add.Tensor(sin, pow_1); sin = pow_1 = None
    return (add,)
```

**inductor**: As discussed above, it is `inductor`'s job to apply optimizations and
generate specialized code. In this case, it has fused `sin` and `square` to run
within the same `for` loop. This allows the generated program to do more compute
per memory read/write, effectively improving the memory bandwidth utilization.

```cpp
extern "C" void kernel(const float* in_ptr0, float* out_ptr0) {
    for (long i0 = 0L; i0 < 16L; i0 += 8L) {
        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + i0);
        auto tmp1 = tmp0.sin();
        auto tmp2 = tmp0 * tmp0;
        auto tmp3 = tmp1 + tmp2;
        tmp3.store(out_ptr0 + i0);
    }
}
```

### Teaching `dynamo` about `torch.func` transforms

Now that we have a basic understanding of how `torch.compile` works, let us
delve into how we extended the support for `torch.func` transforms. Given that
`aot_autograd` is already capable of tracing through the transforms, our task is
to teach `dynamo` to validate whether the user-defined function intended for
transformation is free of side effects on global state and free of graph breaks.
In cases where the function meets these criteria, we can put
the `torch.func` transform into the FX graph and delegate the remaining
processing to the lower layers of the stack.

However, if the function cannot be successfully traced because it fails to
meet the above constraints, we fall back to the eager implementation, and this
particular portion of the code remains uncompiled.

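For example (our own illustration of the constraint, not code from the project),
a transformed function that mutates global state is the kind of case described
above that falls back to eager execution instead of being captured in the graph:

```python
import torch
from torch.func import grad

calls = 0  # global state touched from inside the transformed function

def noisy_sin(x):
    global calls
    calls += 1  # side effect on global state
    return torch.sin(x)

def wrapper_fn(x):
    return grad(noisy_sin)(x)

# Per the fallback behavior described above, this still runs and returns the
# correct gradient, but the `grad` region is executed eagerly rather than
# being compiled.
out = torch.compile(wrapper_fn)(torch.randn(()))
```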
|
Let us have a look at what `dynamo` and `aot_autograd` generate when we compile a
program with `grad`.

```python
# Run this file with `TORCH_COMPILE_DEBUG=1`

import torch

def user_fn(x):
    return torch.sin(x)

def wrapper_fn(x):
    return torch.func.grad(user_fn)(x)

torch.compile(wrapper_fn)(torch.randn(()))
```

The output from `dynamo` is presented below. The first `GraphModule` pertains
to the `wrapper_fn`, clearly indicating a call to `grad` on the traced representation
of the user's function intended for transformation. The second, nested
`GraphModule` corresponds to the function provided by the user. In this instance,
our function had neither side effects nor graph breaks, so we were able to
successfully trace through this program in a single graph.

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: torch/_functorch/apis.py:363, code:
        # return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
        grad_body_0 = self.grad_body_0
        grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
        call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
        contiguous = call.contiguous(); call = None
        return (contiguous,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_x_):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(True)

            # File: test/test_scratch.py:382, code: return torch.sin(x)
            sin = torch.sin(l_x_); l_x_ = None

            # No stacktrace found for following nodes
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
            return sin
```

The graph shown above is handed over to `aot_autograd` for the subsequent phase
of the compilation process. `aot_autograd` traces through the transform itself,
resulting in the generation of the transformed graph. This explains why we
observe a call to `cos` instead of `sin`: because we applied the `grad`
transform, `aot_autograd` traced both the forward and backward computation and
then optimized away the forward output, as `grad` discards that value.

```python
def forward(self, arg0_1: f32[]):
    # File: torch/_functorch/apis.py:363,
    # code: return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
    full: f32[] = torch.ops.aten.full.default([], 1, dtype = torch.float32,
                                              layout = torch.strided,
                                              device = device(type='cpu'),
                                              pin_memory = False)
    cos: f32[] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
    mul: f32[] = torch.ops.aten.mul.Tensor(full, cos); full = cos = None
    return (mul,)
```

The inclusion of `torch.func` support within `torch.compile` is currently under
active development. At present, our support extends
to the compilation of `grad` and `vmap`. However, it is important to note that
there are certain [limitations](https://pytorch.org/docs/main/torch.compiler_faq.html#limitations)
that restrict the range of cases we can compile.

Looking ahead, our roadmap aims to extend support to all the transforms with
minimal limitations, providing more comprehensive compilation support for
`torch.func` transforms.

## Closing Remarks

This project was yet another instance of the tight collaboration between Quansight
and Meta within PyTorch. In particular, we would like to thank Richard Zou and
Horace He, the `torch.func` creators, for all the design discussions and
guidance throughout these years.