Skip to content

Commit 40de723

Browse files
committed
beef up make_functional docs
1 parent 0ba35ba commit 40de723

File tree

2 files changed

+120
-50
lines changed

2 files changed

+120
-50
lines changed

docs/source/functorch.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ functorch API Reference
33

44
.. currentmodule:: functorch
55

6+
Function Transforms
7+
-------------------
68
.. autosummary::
79
:toctree: generated
810
:nosignatures:
@@ -11,3 +13,37 @@ functorch API Reference
1113
grad_and_value
1214
jacrev
1315
vmap
16+
vjp
17+
18+
Utilities for working with torch.nn.Modules
19+
-------------------------------------------
20+
21+
In general, you can transform over a function that calls a `torch.nn.Module`.
22+
For example, the following is an example of computing a jacobian of a function
23+
that takes three values and returns three values:
24+
25+
.. code-block:: python
26+
27+
model = torch.nn.Linear(3, 3)
28+
29+
def f(x):
30+
return model(x)
31+
32+
x = torch.randn(3)
33+
jacobian = jacrev(f)(x)
34+
assert jacobian.shape == (3, 3)
35+
36+
However, if you want to do something like compute a jacobian over the parameters
37+
of the model, then there needs to be a way to construct a function where the
38+
parameters are the inputs to the function.
39+
That's what :func:`make_functional` and :func:`make_functional_with_buffer` are for:
40+
given a `torch.nn.Module`, these return a new function that accepts `parameters`
41+
and the inputs to the Module's forward pass.
42+
43+
.. autosummary::
44+
:toctree: generated
45+
:nosignatures:
46+
47+
make_functional
48+
make_functional_with_buffers
49+
combine_state_for_ensemble

functorch/_src/make_functional.py

Lines changed: 84 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ def make_split_names(lst):
199199
return [name.split('.') for name in lst]
200200

201201
class FunctionalModuleWithBuffers(nn.Module):
202+
"""
203+
This is the callable object returned by :func:`make_functional_with_buffers`.
204+
"""
202205
def __init__(self, stateless_model, param_names, buffer_names):
203206
super(FunctionalModuleWithBuffers, self).__init__()
204207
self.stateless_model = stateless_model
@@ -231,6 +234,9 @@ def forward(self, params, buffers, *args, **kwargs):
231234
_swap_state(self.stateless_model, self.split_names, old_state)
232235

233236
class FunctionalModule(nn.Module):
237+
"""
238+
This is the callable object returned by :func:`make_functional`.
239+
"""
234240
def __init__(self, stateless_model, param_names):
235241
super(FunctionalModule, self).__init__()
236242
self.stateless_model = stateless_model
@@ -254,42 +260,48 @@ def forward(self, params, *args, **kwargs):
254260
_swap_state(self.stateless_model, self.split_names, old_state)
255261

256262
def make_functional(model: nn.Module):
257-
"""make_functional(model) -> func, weights
263+
"""make_functional(model) -> func, params
258264
259-
Given an nn.Module, make_functional extracts the state (weights)
265+
Given a torch.nn.Module, make_functional extracts the state (params)
260266
and returns a functional version of the model, `func`. This makes
261267
it so that it is possible use transforms over the parameters of
262268
`model`.
263269
264270
`func` can be invoked as follows:
265-
```
266-
import torch
267-
import torch.nn as nn
268-
from functorch import make_functional
269271
270-
x = torch.randn(4, 3)
271-
model = nn.Linear(3, 3)
272-
func, params = make_functional(model)
273-
func(params, x)
274-
```
272+
.. code-block:: python
275273
276-
And here is an example of applying the grad transform:
277-
```
278-
import torch
279-
import torch.nn as nn
280-
from functorch import make_functional, grad
274+
import torch
275+
import torch.nn as nn
276+
from functorch import make_functional
281277
282-
x = torch.randn(4, 3)
283-
t = torch.randn(4, 3)
284-
model = nn.Linear(3, 3)
285-
func, params = make_functional(model)
278+
x = torch.randn(4, 3)
279+
model = nn.Linear(3, 3)
280+
func, params = make_functional(model)
281+
func(params, x)
286282
287-
def compute_loss(params, x, t):
288-
y = func(params, x)
289-
return nn.functional.mse_loss(y, t)
283+
And here is an example of applying the grad transform over the parameters
284+
of a model.
285+
286+
.. code-block:: python
287+
288+
import torch
289+
import torch.nn as nn
290+
from functorch import make_functional, grad
291+
292+
x = torch.randn(4, 3)
293+
t = torch.randn(4, 3)
294+
model = nn.Linear(3, 3)
295+
func, params = make_functional(model)
296+
297+
def compute_loss(params, x, t):
298+
y = func(params, x)
299+
return nn.functional.mse_loss(y, t)
300+
301+
grad_weights = grad(compute_loss)(params, x, t)
302+
303+
If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
290304
291-
grad_weights = grad(compute_loss)(params, x, t)
292-
```
293305
"""
294306
buffers = list(model.buffers())
295307
if len(buffers) > 0:
@@ -301,39 +313,43 @@ def compute_loss(params, x, t):
301313
def make_functional_with_buffers(model: nn.Module):
302314
"""make_functional_with_buffers(model) -> func, params, buffers
303315
304-
Given an nn.Module, make_functional_with_buffers extracts the state
316+
Given a torch.nn.Module, make_functional_with_buffers extracts the state
305317
(params and buffers) and returns a functional version of the model `func`
306318
that can be invoked like a function.
307319
308320
`func` can be invoked as follows:
309-
```
310-
import torch
311-
import torch.nn as nn
312-
from functorch import make_functional_with_buffers
313321
314-
x = torch.randn(4, 3)
315-
model = nn.Linear(3, 3)
316-
func, params, buffers = make_functional_with_buffers(model)
317-
func(params, buffers, x)
318-
```
322+
.. code-block:: python
319323
320-
And here is an example of applying the grad transform:
321-
```
322-
import torch
323-
import torch.nn as nn
324-
from functorch import make_functional_with_buffers, grad
324+
import torch
325+
import torch.nn as nn
326+
from functorch import make_functional_with_buffers
325327
326-
x = torch.randn(4, 3)
327-
t = torch.randn(4, 3)
328-
model = nn.Linear(3, 3)
329-
func, params, buffers = make_functional_with_buffers(model)
328+
x = torch.randn(4, 3)
329+
model = nn.Linear(3, 3)
330+
func, params, buffers = make_functional_with_buffers(model)
331+
func(params, buffers, x)
330332
331-
def compute_loss(params, buffers, x, t):
332-
y = func(params, buffers, x)
333-
return nn.functional.mse_loss(y, t)
333+
And here is an example of applying the grad transform over the parameters
334+
of a model:
335+
336+
.. code-block:: python
337+
338+
import torch
339+
import torch.nn as nn
340+
from functorch import make_functional_with_buffers, grad
341+
342+
x = torch.randn(4, 3)
343+
t = torch.randn(4, 3)
344+
model = nn.Linear(3, 3)
345+
func, params, buffers = make_functional_with_buffers(model)
346+
347+
def compute_loss(params, buffers, x, t):
348+
y = func(params, buffers, x)
349+
return nn.functional.mse_loss(y, t)
350+
351+
grad_weights = grad(compute_loss)(params, buffers, x, t)
334352
335-
grad_weights = grad(compute_loss)(params, buffers, x, t)
336-
```
337353
"""
338354
return FunctionalModuleWithBuffers._create_from(model)
339355

@@ -347,6 +363,8 @@ def transpose_stack(tuple_of_tuple_of_tensors):
347363
def combine_state_for_ensemble(models):
348364
"""combine_state_for_ensemble(models) -> func, params, buffers
349365
366+
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
367+
350368
Given a list of `M` nn.Modules of the same class, stacks all of their
351369
parameters and buffers together to make `params` and `buffers`.
352370
Each parameter and buffer in the result will have an additional dimension
@@ -355,7 +373,23 @@ def combine_state_for_ensemble(models):
355373
`combine_state_for_ensemble` also returns `func`, a functional version
356374
of one of the models in `models`. One cannot directly run
357375
`func(params, buffers, *args, **kwargs)` directly, you probably want to
358-
use vmap(func, ...)(params, buffers, *args, **kwargs)
376+
use `vmap(func, ...)(params, buffers, *args, **kwargs)`
377+
378+
Here's an example of how to ensemble over a very simple model:
379+
380+
.. code-block:: python
381+
382+
num_models = 5
383+
batch_size = 64
384+
in_features, out_features = 3, 3
385+
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
386+
data = torch.randn(batch_size, 3)
387+
388+
fmodel, params, buffers = combine_state_for_ensemble(models)
389+
output = vmap(fmodel, (0, 0, None))(params, buffers, data)
390+
391+
assert output.shape == (num_models, batch_size, out_features)
392+
359393
"""
360394
funcs, params, buffers = zip(*[make_functional_with_buffers(model)
361395
for model in models])

0 commit comments

Comments
 (0)