7 | 7 | "source": [
8 | 8 | "# Whirlwind Tour\n",
9 | 9 | "\n",
10 | | - "functorch is [JAX](https://github.com/google/jax)-like composable function transforms for PyTorch. In this whirlwind tour, we'll introduce all the functorch transforms.\n",
| 10 | + "\n",
| 11 | + "## What is functorch?\n",
| 12 | + "\n",
| 13 | + "functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n",
| 14 | + "- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n",
| 15 | + "- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n",
| 16 | + "- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n",
| 17 | + "\n",
| 18 | + "Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns to you an [FX graph](https://pytorch.org/docs/stable/fx.html) (that optionally contains a backward pass), of which compilation via various backends is one path you can take.\n",
11 | 19 | "\n",
12 | 20 | "\n",
13 | 21 | "## Why composable function transforms?\n",

18 | 26 | "- efficiently computing Jacobians and Hessians\n",
19 | 27 | "- efficiently computing batched Jacobians and Hessians\n",
20 | 28 | "\n",
21 | | - "Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).\n",
| 29 | + "Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n",
22 | 30 | "\n",
23 | 31 | "## What are the transforms?\n",
24 | 32 | "\n",
25 | | - "Right now, we support the following transforms:\n",
26 | | - "\n",
27 | | - "- `grad`, `vjp`, `jvp`,\n",
28 | | - "- `jacrev`, `jacfwd`, `hessian`\n",
29 | | - "- `vmap`\n",
| 33 | + "### grad (gradient computation)\n",
30 | 34 | "\n",
31 | | - "Furthermore, we have some utilities for working with PyTorch modules.\n",
32 | | - "- `make_functional(model)`\n",
33 | | - "- `make_functional_with_buffers(model)`\n",
34 | | - "\n",
35 | | - "### vmap\n",
36 | | - "\n",
37 | | - "Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
38 | | - "\n",
39 | | - "`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in inputs.\n",
40 | | - "\n",
41 | | - "vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience:"
| 35 | + "`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. the first input."
42 | 36 | ]
43 | 37 | },
44 | 38 | {
45 | 39 | "cell_type": "code",
46 | | - "execution_count": 1,
| 40 | + "execution_count": null,
47 | 41 | "id": "f920b923",
48 | 42 | "metadata": {},
49 | 43 | "outputs": [],
50 | 44 | "source": [
51 | | - "import torch\n",
52 | | - "from functorch import vmap\n",
53 | | - "batch_size, feature_size = 3, 5\n",
54 | | - "weights = torch.randn(feature_size, requires_grad=True)\n",
55 | | - "\n",
56 | | - "def model(feature_vec):\n",
57 | | - "    # Very simple linear model with activation\n",
58 | | - "    assert feature_vec.dim() == 1\n",
59 | | - "    return feature_vec.dot(weights).relu()\n",
| 45 | + "from functorch import grad\n",
| 46 | + "x = torch.randn([])\n",
| 47 | + "cos_x = grad(lambda x: torch.sin(x))(x)\n",
| 48 | + "assert torch.allclose(cos_x, x.cos())\n",
60 | 49 | "\n",
61 | | - "examples = torch.randn(batch_size, feature_size)\n",
62 | | - "result = vmap(model)(examples)"
| 50 | + "# Second-order gradients\n",
| 51 | + "neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
| 52 | + "assert torch.allclose(neg_sin_x, -x.sin())"
63 | 53 | ]
64 | 54 | },
65 | 55 | {
66 | 56 | "cell_type": "markdown",
67 | 57 | "id": "ef3b2d85",
68 | 58 | "metadata": {},
69 | 59 | "source": [
70 | | - "### grad\n",
| 60 | + "### vmap (auto-vectorization)\n",
| 61 | + "\n",
| 62 | + "Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
| 63 | + "\n",
| 64 | + "`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in inputs.\n",
71 | 65 | "\n",
72 | | - "`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. By default, it computes the gradients of the output of `func` w.r.t. to `inputs[0]`."
| 66 | + "vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience:"
73 | 67 | ]
74 | 68 | },
75 | 69 | {
76 | 70 | "cell_type": "code",
77 | | - "execution_count": 2,
| 71 | + "execution_count": null,
78 | 72 | "id": "6ebac649",
79 | 73 | "metadata": {},
80 | 74 | "outputs": [],
81 | 75 | "source": [
82 | | - "from functorch import grad\n",
83 | | - "x = torch.randn([])\n",
84 | | - "cos_x = grad(lambda x: torch.sin(x))(x)\n",
85 | | - "assert torch.allclose(cos_x, x.cos())\n",
| 76 | + "import torch\n",
| 77 | + "from functorch import vmap\n",
| 78 | + "batch_size, feature_size = 3, 5\n",
| 79 | + "weights = torch.randn(feature_size, requires_grad=True)\n",
86 | 80 | "\n",
87 | | - "# Second-order gradients\n",
88 | | - "neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
89 | | - "assert torch.allclose(neg_sin_x, -x.sin())"
| 81 | + "def model(feature_vec):\n",
| 82 | + "    # Very simple linear model with activation\n",
| 83 | + "    assert feature_vec.dim() == 1\n",
| 84 | + "    return feature_vec.dot(weights).relu()\n",
| 85 | + "\n",
| 86 | + "examples = torch.randn(batch_size, feature_size)\n",
| 87 | + "result = vmap(model)(examples)"
90 | 88 | ]
91 | 89 | },
92 | 90 | {
93 | 91 | "cell_type": "markdown",
94 | 92 | "id": "5161e6d2",
95 | 93 | "metadata": {},
96 | 94 | "source": [
97 | | - "When composed with vmap, grad can be used to compute per-sample-gradients:"
| 95 | + "When composed with `grad`, `vmap` can be used to compute per-sample-gradients:"
98 | 96 | ]
99 | 97 | },
100 | 98 | {
101 | 99 | "cell_type": "code",
102 | | - "execution_count": 3,
| 100 | + "execution_count": null,
103 | 101 | "id": "ffb2fcb1",
104 | 102 | "metadata": {},
105 | 103 | "outputs": [],

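The body of this per-sample-gradients cell is elided from the diff. As an illustrative aside (not the notebook's actual code), here is a minimal sketch of composing `vmap` and `grad` for per-sample gradients; the `loss` helper, `targets` tensor, and `in_dims` choice are assumptions made up for this example, with the same toy linear model as in the vmap cell above:

```python
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)  # made-up per-example targets

def loss(weights, feature_vec, target):
    # Same simple linear model as above, with a squared-error loss per example
    out = feature_vec.dot(weights).relu()
    return (out - target) ** 2

# grad differentiates w.r.t. weights (the first argument); vmap maps the
# resulting function over the batch dimension of examples and targets only
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weights, examples, targets)
assert per_sample_grads.shape == (batch_size, feature_size)
```
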
128 | 126 | "id": "11d711af",
129 | 127 | "metadata": {},
130 | 128 | "source": [
131 | | - "### vjp\n",
| 129 | + "### vjp (vector-Jacobian product)\n",
132 | 130 | "\n",
133 | | - "The `vjp` transform applies `func` to `inputs` and returns a new function that computes vjps given some `cotangents` Tensors."
| 131 | + "The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors."
134 | 132 | ]
135 | 133 | },
136 | 134 | {
137 | 135 | "cell_type": "code",
138 | | - "execution_count": 4,
| 136 | + "execution_count": null,
139 | 137 | "id": "ad48f9d4",
140 | 138 | "metadata": {},
141 | 139 | "outputs": [],

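The source of this vjp cell is elided from the diff. A minimal sketch of the transform described above, assuming the `functorch.vjp` API (`vjp(func, *primals)` returning the output plus a vjp function); the choice of `torch.sin` and the all-ones cotangent is illustrative only:

```python
import torch
from functorch import vjp

x = torch.randn(5)
# vjp applies torch.sin to x and returns the output together with a function
# that computes vector-Jacobian products for given cotangents
out, vjp_fn = vjp(torch.sin, x)
cotangents = torch.ones(5)
(vjp_out,) = vjp_fn(cotangents)
# For sin, the Jacobian is diag(cos(x)), so the vjp is cos(x) * cotangents
assert torch.allclose(vjp_out, x.cos() * cotangents)
```
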
154 | 152 | "id": "e0221270",
155 | 153 | "metadata": {},
156 | 154 | "source": [
157 | | - "### jvp\n",
| 155 | + "### jvp (Jacobian-vector product)\n",
158 | 156 | "\n",
159 | 157 | "The `jvp` transform computes Jacobian-vector-products and is also known as \"forward-mode AD\". It is not a higher-order function unlike most other transforms, but it returns the outputs of `func(inputs)` as well as the jvps."
160 | 158 | ]
161 | 159 | },
162 | 160 | {
163 | 161 | "cell_type": "code",
164 | | - "execution_count": 5,
| 162 | + "execution_count": null,
165 | 163 | "id": "f3772f43",
166 | 164 | "metadata": {},
167 | 165 | "outputs": [],

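This cell's source is also elided from the diff. A minimal sketch of jvp as described above, assuming the `functorch.jvp` API (`jvp(func, primals, tangents)` with primals and tangents passed as tuples); `torch.sin` and the all-ones tangent are illustrative only:

```python
import torch
from functorch import jvp

x = torch.randn(5)
v = torch.ones(5)
# jvp evaluates torch.sin at x and pushes the tangent v forward through it,
# returning both the primal output and the Jacobian-vector product
out, jvp_out = jvp(torch.sin, (x,), (v,))
# For sin, the Jacobian is diag(cos(x)), so the jvp is cos(x) * v
assert torch.allclose(jvp_out, x.cos() * v)
```
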
187 | 185 | },
188 | 186 | {
189 | 187 | "cell_type": "code",
190 | | - "execution_count": 6,
| 188 | + "execution_count": null,
191 | 189 | "id": "20f53be2",
192 | 190 | "metadata": {},
193 | 191 | "outputs": [],

209 | 207 | },
210 | 208 | {
211 | 209 | "cell_type": "code",
212 | | - "execution_count": 7,
| 210 | + "execution_count": null,
213 | 211 | "id": "97d6c382",
214 | 212 | "metadata": {},
215 | 213 | "outputs": [],

229 | 227 | },
230 | 228 | {
231 | 229 | "cell_type": "code",
232 | | - "execution_count": 8,
| 230 | + "execution_count": null,
233 | 231 | "id": "a8c1dedb",
234 | 232 | "metadata": {},
235 | 233 | "outputs": [],

251 | 249 | },
252 | 250 | {
253 | 251 | "cell_type": "code",
254 | | - "execution_count": 9,
| 252 | + "execution_count": null,
255 | 253 | "id": "1e511139",
256 | 254 | "metadata": {},
257 | 255 | "outputs": [],

274 | 272 | },
275 | 273 | {
276 | 274 | "cell_type": "code",
277 | | - "execution_count": 10,
| 275 | + "execution_count": null,
278 | 276 | "id": "fd1765df",
279 | 277 | "metadata": {},
280 | 278 | "outputs": [],

315 | 313 | "name": "python",
316 | 314 | "nbconvert_exporter": "python",
317 | 315 | "pygments_lexer": "ipython3",
318 | | - "version": "3.9.4"
| 316 | + "version": "3.7.4"
319 | 317 | }
320 | 318 | },
321 | 319 | "nbformat": 4,