 from abc import ABC, abstractmethod
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
 
 import torch
 from torch import Tensor, nn
@@ -45,31 +45,21 @@ def __init__(self, module: nn.Module):
                 self.frozen_params[name] = param
 
 
-class Vmapped(VJP):
-    """VJP wrapper that applies the wrapped VJP, vmapped on the first dimension."""
-
-    def __init__(self, vjp: VJP):
-        super().__init__()
-        self.vmapped_vjp = torch.vmap(vjp)
-
-    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
-        return self.vmapped_vjp(grad_outputs, inputs)
-
-
 class FunctionalVJP(ModuleVJP):
     """
     Represents a VJP function for a module's forward pass with respect to its parameters using the
-    func api. The __call__ function takes both the inputs and the cotangents that can be vmapped
-    jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using
-    this method is that it makes an extra forward pass.
-
-    :params module: The module to differentiate.
+    functional differentiation API. This requires vmap, so it is not compatible with every
+    module, and it requires an extra forward pass to create the vjp function.
     """
 
     def __init__(self, module: nn.Module):
         super().__init__(module)
+        self.vmapped_vjp = torch.vmap(self._call_on_one_instance)
+
+    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
+        return self.vmapped_vjp(grad_outputs, inputs)
 
-    def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
+    def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
         # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
         # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
         # nn.Flatten) do not work equivalently if they're provided with a batch or with
@@ -78,32 +68,20 @@ def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor
         inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j)
         grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)
 
-        # _vjp_from_module returns a function that computes the vjp w.r.t. to the
-        # primals (tuple), here the functional has a single primal which is
-        # dict(module.named_parameters()). We therefore take the 0'th element to obtain
-        # the dict of gradients w.r.t. the module's named_parameters.
-        return self._vjp_from_module(inputs_j)(grad_outputs_j)[0]
-
-    def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]:
-        """
-        Create a VJP function for a module's forward pass with respect to its parameters.
-
-        Returns a function that computes vector-Jacobian products for the module's parameters given
-        fixed inputs. Only parameters with requires_grad=True are included in the differentiation.
-
-        :param inputs: Fixed inputs to the module for the VJP computation.
-        :returns: VJP function that takes cotangents and returns parameter gradients.
-        """
-
         def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
             all_state = {
                 **primals,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
-            return torch.func.functional_call(self.module, all_state, inputs)
+            return torch.func.functional_call(self.module, all_state, inputs_j)
+
+        vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
 
-        return torch.func.vjp(functional_model_call, self.trainable_params)[1]
+        # vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
+        # functional has a single primal, which is dict(module.named_parameters()). We therefore
+        # take the 0'th element to obtain the dict of gradients w.r.t. the module's named_parameters.
+        return vjp_func(grad_outputs_j)[0]
 
 
 class AutogradVJP(ModuleVJP):
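
For context on the torch.func mechanics this commit relies on, here is a minimal, standalone sketch (not part of the diff) of computing a VJP w.r.t. a module's parameters with torch.func.functional_call and torch.func.vjp, as FunctionalVJP does for one instance. The module `layer`, the input `x_j`, and the cotangent `cotangent_j` are made-up placeholders, not names from this codebase.

import torch
from torch import nn

# Hypothetical stand-ins for the module and data handled by FunctionalVJP.
layer = nn.Linear(3, 2)
params = dict(layer.named_parameters())
x_j = torch.randn(1, 3)          # a single instance, as a "batch" of 1
cotangent_j = torch.randn(1, 2)  # cotangent with the same shape as the output


def functional_model_call(primals: dict[str, torch.Tensor]) -> torch.Tensor:
    # Run the module with `primals` substituted for its registered parameters.
    return torch.func.functional_call(layer, primals, x_j)


# torch.func.vjp returns (output, vjp_fn); vjp_fn maps a cotangent to a tuple with one
# gradient per primal. The single primal here is the parameter dict, so element 0 is a
# dict mapping parameter names to gradients with the same shapes as the parameters.
output, vjp_fn = torch.func.vjp(functional_model_call, params)
grads = vjp_fn(cotangent_j)[0]
print({name: g.shape for name, g in grads.items()})
# {'weight': torch.Size([2, 3]), 'bias': torch.Size([2])}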
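A second sketch, under the same assumptions, shows the vmap wrapping that the new __call__ performs via self.vmapped_vjp: the per-instance VJP is mapped over the leading batch dimension to obtain per-sample parameter gradients. The function `vjp_on_one_instance` below is an illustrative analogue of _call_on_one_instance, not the code in the diff.

import torch
from torch import nn

layer = nn.Linear(3, 2)                  # same hypothetical module as above
params = dict(layer.named_parameters())


def vjp_on_one_instance(cotangent_j: torch.Tensor, x_j: torch.Tensor) -> dict[str, torch.Tensor]:
    # Unsqueeze the single instance into a batch of 1, because some layers
    # (e.g. nn.Flatten) behave differently without a batch dimension.
    x_j = x_j.unsqueeze(0)
    cotangent_j = cotangent_j.unsqueeze(0)

    def functional_model_call(primals: dict[str, torch.Tensor]) -> torch.Tensor:
        return torch.func.functional_call(layer, primals, x_j)

    vjp_fn = torch.func.vjp(functional_model_call, params)[1]
    return vjp_fn(cotangent_j)[0]


batched_vjp = torch.vmap(vjp_on_one_instance)
xs = torch.randn(8, 3)
cotangents = torch.randn(8, 2)
per_sample_grads = batched_vjp(cotangents, xs)
# Every gradient now has a leading batch dimension of 8.
print({name: g.shape for name, g in per_sample_grads.items()})
# {'weight': torch.Size([8, 2, 3]), 'bias': torch.Size([8, 2])}

This is the standard per-sample-gradient composition of functional_call, vjp, and vmap; the trade-off named in the new docstring is that building the vjp function runs an extra forward pass, and the whole pipeline must be vmap-compatible.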