
Commit 4afe33b

Better intro to functorch (#688)
For our docs landing page. Fixes #605
1 parent c2016e4 commit 4afe33b

File tree: 2 files changed, +67 −58 lines


docs/source/index.rst

Lines changed: 16 additions & 5 deletions
@@ -7,18 +7,29 @@ functorch
 
 functorch is `JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.
 
-It aims to provide composable vmap and grad transforms that work with PyTorch modules
-and PyTorch autograd with good eager-mode performance.
-
 .. note::
-   This library is currently in [beta](https://pytorch.org/blog/pytorch-feature-classification-changes/#beta).
+   This library is currently in `beta <https://pytorch.org/blog/pytorch-feature-classification-changes/#beta>`_.
    What this means is that the features generally work (unless otherwise documented)
    and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
    may change under user feedback and we don't have full coverage over PyTorch operations.
 
    If you have suggestions on the API or use-cases you'd like to be covered, please
    open an github issue or reach out. We'd love to hear about how you're using the library.
 
+What are composable function transforms?
+----------------------------------------
+
+- A "function transform" is a higher-order function that accepts a numerical function
+  and returns a new function that computes a different quantity.
+
+- functorch has auto-differentiation transforms (``grad(f)`` returns a function that
+  computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
+  returns a function that computes ``f`` over batches of inputs), and others.
+
+- These function transforms can compose with each other arbitrarily. For example,
+  composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that
+  stock PyTorch cannot efficiently compute today.
+
 Why composable function transforms?
 -----------------------------------
 
@@ -36,7 +47,7 @@ This idea of composable function transforms comes from the `JAX framework <https
 Read More
 ---------
 
-For a whirlwind tour of how to use the transforms, please check out `this section in our README <https://github.com/pytorch/functorch/blob/main/README.md#what-are-the-transforms>`_. For installation instructions or the API reference, please check below.
+Check out our `whirlwind tour <whirlwind_tour>`_ or some of our tutorials mentioned below.
 
 
 .. toctree::
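The new "What are composable function transforms?" section above ends with the claim that composing ``vmap(grad(f))`` yields per-sample-gradients. A minimal sketch of that composition, not part of this commit; the toy linear model, shapes, and names below are made up for illustration:

import torch
from functorch import grad, vmap

# Hypothetical setup: one weight vector and a batch of (example, target) pairs.
weights = torch.randn(5)
examples = torch.randn(8, 5)
targets = torch.randn(8)

def loss_for_one_sample(weights, example, target):
    # Squared error on a single example; grad() requires a scalar output.
    prediction = example.dot(weights)
    return (prediction - target) ** 2

# grad differentiates w.r.t. the first argument (weights).
# vmap with in_dims=(None, 0, 0) maps over the batch dimension of examples/targets
# while broadcasting the same weights to every sample.
per_sample_grads = vmap(grad(loss_for_one_sample), in_dims=(None, 0, 0))(weights, examples, targets)
assert per_sample_grads.shape == (8, 5)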

notebooks/whirlwind_tour.ipynb

Lines changed: 51 additions & 53 deletions
@@ -7,7 +7,15 @@
 "source": [
 "# Whirlwind Tour\n",
 "\n",
-"functorch is [JAX](https://github.com/google/jax)-like composable function transforms for PyTorch. In this whirlwind tour, we'll introduce all the functorch transforms.\n",
+"\n",
+"## What is functorch?\n",
+"\n",
+"functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n",
+"- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n",
+"- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n",
+"- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n",
+"\n",
+"Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns to you an [FX graph](https://pytorch.org/docs/stable/fx.html) (that optionally contains a backward pass), of which compilation via various backends is one path you can take.\n",
 "\n",
 "\n",
 "## Why composable function transforms?\n",
@@ -18,88 +26,78 @@
 "- efficiently computing Jacobians and Hessians\n",
 "- efficiently computing batched Jacobians and Hessians\n",
 "\n",
-"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the [JAX framework](https://github.com/google/jax).\n",
+"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n",
 "\n",
 "## What are the transforms?\n",
 "\n",
-"Right now, we support the following transforms:\n",
-"\n",
-"- `grad`, `vjp`, `jvp`,\n",
-"- `jacrev`, `jacfwd`, `hessian`\n",
-"- `vmap`\n",
+"### grad (gradient computation)\n",
 "\n",
-"Furthermore, we have some utilities for working with PyTorch modules.\n",
-"- `make_functional(model)`\n",
-"- `make_functional_with_buffers(model)`\n",
-"\n",
-"### vmap\n",
-"\n",
-"Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
-"\n",
-"`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in inputs.\n",
-"\n",
-"vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience:"
+"`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. to the first input."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "id": "f920b923",
 "metadata": {},
 "outputs": [],
 "source": [
-"import torch\n",
-"from functorch import vmap\n",
-"batch_size, feature_size = 3, 5\n",
-"weights = torch.randn(feature_size, requires_grad=True)\n",
-"\n",
-"def model(feature_vec):\n",
-"    # Very simple linear model with activation\n",
-"    assert feature_vec.dim() == 1\n",
-"    return feature_vec.dot(weights).relu()\n",
+"from functorch import grad\n",
+"x = torch.randn([])\n",
+"cos_x = grad(lambda x: torch.sin(x))(x)\n",
+"assert torch.allclose(cos_x, x.cos())\n",
 "\n",
-"examples = torch.randn(batch_size, feature_size)\n",
-"result = vmap(model)(examples)"
+"# Second-order gradients\n",
+"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
+"assert torch.allclose(neg_sin_x, -x.sin())"
 ]
 },
 {
 "cell_type": "markdown",
 "id": "ef3b2d85",
 "metadata": {},
 "source": [
-"### grad\n",
+"### vmap (auto-vectorization)\n",
+"\n",
+"Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
+"\n",
+"`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in inputs.\n",
 "\n",
-"`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. By default, it computes the gradients of the output of `func` w.r.t. to `inputs[0]`."
+"vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience:"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
 "id": "6ebac649",
 "metadata": {},
 "outputs": [],
 "source": [
-"from functorch import grad\n",
-"x = torch.randn([])\n",
-"cos_x = grad(lambda x: torch.sin(x))(x)\n",
-"assert torch.allclose(cos_x, x.cos())\n",
+"import torch\n",
+"from functorch import vmap\n",
+"batch_size, feature_size = 3, 5\n",
+"weights = torch.randn(feature_size, requires_grad=True)\n",
 "\n",
-"# Second-order gradients\n",
-"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
-"assert torch.allclose(neg_sin_x, -x.sin())"
+"def model(feature_vec):\n",
+"    # Very simple linear model with activation\n",
+"    assert feature_vec.dim() == 1\n",
+"    return feature_vec.dot(weights).relu()\n",
+"\n",
+"examples = torch.randn(batch_size, feature_size)\n",
+"result = vmap(model)(examples)"
 ]
 },
 {
 "cell_type": "markdown",
 "id": "5161e6d2",
 "metadata": {},
 "source": [
-"When composed with vmap, grad can be used to compute per-sample-gradients:"
+"When composed with `grad`, `vmap` can be used to compute per-sample-gradients:"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "id": "ffb2fcb1",
 "metadata": {},
 "outputs": [],
@@ -128,14 +126,14 @@
 "id": "11d711af",
 "metadata": {},
 "source": [
-"### vjp\n",
+"### vjp (vector-Jacobian product)\n",
 "\n",
-"The `vjp` transform applies `func` to `inputs` and returns a new function that computes vjps given some `cotangents` Tensors."
+"The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
 "id": "ad48f9d4",
 "metadata": {},
 "outputs": [],
@@ -154,14 +152,14 @@
 "id": "e0221270",
 "metadata": {},
 "source": [
-"### jvp\n",
+"### jvp (Jacobian-vector product)\n",
 "\n",
 "The `jvp` transforms computes Jacobian-vector-products and is also known as \"forward-mode AD\". It is not a higher-order function unlike most other transforms, but it returns the outputs of `func(inputs)` as well as the jvps."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "id": "f3772f43",
 "metadata": {},
 "outputs": [],
@@ -187,7 +185,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "id": "20f53be2",
 "metadata": {},
 "outputs": [],
@@ -209,7 +207,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": null,
 "id": "97d6c382",
 "metadata": {},
 "outputs": [],
@@ -229,7 +227,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": null,
 "id": "a8c1dedb",
 "metadata": {},
 "outputs": [],
@@ -251,7 +249,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": null,
 "id": "1e511139",
 "metadata": {},
 "outputs": [],
@@ -274,7 +272,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": null,
 "id": "fd1765df",
 "metadata": {},
 "outputs": [],
@@ -315,7 +313,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.4"
+"version": "3.7.4"
 }
 },
 "nbformat": 4,
