@@ -199,6 +199,9 @@ def make_split_names(lst):
     return [name.split('.') for name in lst]
 
 class FunctionalModuleWithBuffers(nn.Module):
+    """
+    This is the callable object returned by :func:`make_functional_with_buffers`.
+    """
     def __init__(self, stateless_model, param_names, buffer_names):
         super(FunctionalModuleWithBuffers, self).__init__()
         self.stateless_model = stateless_model
@@ -231,6 +234,9 @@ def forward(self, params, buffers, *args, **kwargs):
         _swap_state(self.stateless_model, self.split_names, old_state)
 
 class FunctionalModule(nn.Module):
+    """
+    This is the callable object returned by :func:`make_functional`.
+    """
     def __init__(self, stateless_model, param_names):
         super(FunctionalModule, self).__init__()
         self.stateless_model = stateless_model
@@ -254,42 +260,48 @@ def forward(self, params, *args, **kwargs):
         _swap_state(self.stateless_model, self.split_names, old_state)
 
 def make_functional(model: nn.Module):
-    """make_functional(model) -> func, weights
+    """make_functional(model) -> func, params
 
-    Given an nn.Module, make_functional extracts the state (weights)
+    Given a torch.nn.Module, make_functional extracts the state (params)
     and returns a functional version of the model, `func`. This makes
     it so that it is possible use transforms over the parameters of
     `model`.
 
     `func` can be invoked as follows:
-    ```
-    import torch
-    import torch.nn as nn
-    from functorch import make_functional
 
-    x = torch.randn(4, 3)
-    model = nn.Linear(3, 3)
-    func, params = make_functional(model)
-    func(params, x)
-    ```
+    .. code-block:: python
 
-    And here is an example of applying the grad transform:
-    ```
-    import torch
-    import torch.nn as nn
-    from functorch import make_functional, grad
+        import torch
+        import torch.nn as nn
+        from functorch import make_functional
 
-    x = torch.randn(4, 3)
-    t = torch.randn(4, 3)
-    model = nn.Linear(3, 3)
-    func, params = make_functional(model)
+        x = torch.randn(4, 3)
+        model = nn.Linear(3, 3)
+        func, params = make_functional(model)
+        func(params, x)
 
-    def compute_loss(params, x, t):
-        y = func(params, x)
-        return nn.functional.mse_loss(y, t)
+    And here is an example of applying the grad transform over the parameters
+    of a model.
+
+    .. code-block:: python
+
+        import torch
+        import torch.nn as nn
+        from functorch import make_functional, grad
+
+        x = torch.randn(4, 3)
+        t = torch.randn(4, 3)
+        model = nn.Linear(3, 3)
+        func, params = make_functional(model)
+
+        def compute_loss(params, x, t):
+            y = func(params, x)
+            return nn.functional.mse_loss(y, t)
+
+        grad_weights = grad(compute_loss)(params, x, t)
+
+    If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
 
-    grad_weights = grad(compute_loss)(params, x, t)
-    ```
     """
     buffers = list(model.buffers())
     if len(buffers) > 0:
@@ -301,39 +313,43 @@ def compute_loss(params, x, t):
 def make_functional_with_buffers(model: nn.Module):
     """make_functional_with_buffers(model) -> func, params, buffers
 
-    Given an nn.Module, make_functional_with_buffers extracts the state
+    Given a torch.nn.Module, make_functional_with_buffers extracts the state
     (params and buffers) and returns a functional version of the model `func`
     that can be invoked like a function.
 
     `func` can be invoked as follows:
-    ```
-    import torch
-    import torch.nn as nn
-    from functorch import make_functional_with_buffers
 
-    x = torch.randn(4, 3)
-    model = nn.Linear(3, 3)
-    func, params, buffers = make_functional_with_buffers(model)
-    func(params, buffers, x)
-    ```
+    .. code-block:: python
 
-    And here is an example of applying the grad transform:
-    ```
-    import torch
-    import torch.nn as nn
-    from functorch import make_functional_with_buffers, grad
+        import torch
+        import torch.nn as nn
+        from functorch import make_functional_with_buffers
 
-    x = torch.randn(4, 3)
-    t = torch.randn(4, 3)
-    model = nn.Linear(3, 3)
-    func, params, buffers = make_functional_with_buffers(model)
+        x = torch.randn(4, 3)
+        model = nn.Linear(3, 3)
+        func, params, buffers = make_functional_with_buffers(model)
+        func(params, buffers, x)
 
-    def compute_loss(params, buffers, x, t):
-        y = func(params, buffers, x)
-        return nn.functional.mse_loss(y, t)
+    And here is an example of applying the grad transform over the parameters
+    of a model:
+
+    .. code-block:: python
+
+        import torch
+        import torch.nn as nn
+        from functorch import make_functional_with_buffers, grad
+
+        x = torch.randn(4, 3)
+        t = torch.randn(4, 3)
+        model = nn.Linear(3, 3)
+        func, params, buffers = make_functional_with_buffers(model)
+
+        def compute_loss(params, buffers, x, t):
+            y = func(params, buffers, x)
+            return nn.functional.mse_loss(y, t)
+
+        grad_weights = grad(compute_loss)(params, buffers, x, t)
 
-    grad_weights = grad(compute_loss)(params, buffers, x, t)
-    ```
     """
     return FunctionalModuleWithBuffers._create_from(model)
 
@@ -347,6 +363,8 @@ def transpose_stack(tuple_of_tuple_of_tensors):
 def combine_state_for_ensemble(models):
     """combine_state_for_ensemble(models) -> func, params, buffers
 
+    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
+
     Given a list of `M` nn.Modules of the same class, stacks all of their
     parameters and buffers together to make `params` and `buffers`.
     Each parameter and buffer in the result will have an additional dimension
@@ -355,7 +373,26 @@ def combine_state_for_ensemble(models):
     `combine_state_for_ensemble` also returns `func`, a functional version
     of one of the models in `models`. One cannot directly run
     `func(params, buffers, *args, **kwargs)` directly, you probably want to
-    use vmap(func, ...)(params, buffers, *args, **kwargs)
+    use `vmap(func, ...)(params, buffers, *args, **kwargs)`
+
+    Here's an example of how to ensemble over a very simple model:
+
+    .. code-block:: python
+
+        import torch
+        from functorch import combine_state_for_ensemble, vmap
+
+        num_models = 5
+        batch_size = 64
+        in_features, out_features = 3, 3
+        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
+        data = torch.randn(batch_size, 3)
+
+        fmodel, params, buffers = combine_state_for_ensemble(models)
+        output = vmap(fmodel, (0, 0, None))(params, buffers, data)
+
+        assert output.shape == (num_models, batch_size, out_features)
+
     """
     funcs, params, buffers = zip(*[make_functional_with_buffers(model)
                                    for model in models])
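
For reference, the `func` returned by `make_functional` is an instance of the `FunctionalModule` documented above: itself an `nn.Module`, but one whose `forward` takes the extracted parameters explicitly. A minimal sketch (assuming the top-level `functorch` API as of this commit):

```python
import torch
import torch.nn as nn
from functorch import make_functional

model = nn.Linear(3, 3)
func, params = make_functional(model)

# `func` is a FunctionalModule: an nn.Module whose forward() expects
# the extracted parameters as its first argument.
assert isinstance(func, nn.Module)
out = func(params, torch.randn(4, 3))
assert out.shape == (4, 3)
```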
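
The new "If the model has any buffers" note matches the guard visible at the end of the `make_functional` hunk (`if len(buffers) > 0: ...`). A minimal sketch of when each variant applies, assuming the same API; `nn.BatchNorm1d` is used here only as a convenient module that registers buffers:

```python
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers

# nn.Linear has parameters but no buffers, so make_functional suffices.
func, params = make_functional(nn.Linear(3, 3))

# nn.BatchNorm1d registers buffers (running_mean, running_var, ...),
# so the buffer-aware variant must be used instead.
func_bn, params_bn, buffers_bn = make_functional_with_buffers(nn.BatchNorm1d(3))
out = func_bn(params_bn, buffers_bn, torch.randn(4, 3))
```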