@@ -19,7 +19,7 @@ class VJP(ABC):
1919 """Represents an abstract VJP function."""
2020
2121 @abstractmethod
22- def __call__ (self , grad_outputs : PyTree , inputs : PyTree ) -> dict [str , Tensor ]:
22+ def __call__ (self , grad_outputs : PyTree , args : PyTree ) -> dict [str , Tensor ]:
2323 """
2424 Computes and returns the dictionary of parameter names to their gradients for the given
2525 grad_outputs (cotangents) and at the given inputs.
@@ -56,25 +56,25 @@ def __init__(self, module: nn.Module):
         super().__init__(module)
         self.vmapped_vjp = torch.vmap(self._call_on_one_instance)

-    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
-        return self.vmapped_vjp(grad_outputs, inputs)
+    def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
+        return self.vmapped_vjp(grad_outputs, args)

-    def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
+    def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]:
         # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
         # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
         # nn.Flatten) do not work equivalently if they're provided with a batch or with
         # an element of a batch. We thus always provide them with batches, just of a
         # different size.
-        inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j)
+        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
         grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)

-        def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
+        def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor:
             all_state = {
-                **primals,
+                **trainable_params,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
-            return torch.func.functional_call(self.module, all_state, inputs_j)
+            return torch.func.functional_call(self.module, all_state, args_j)

         vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]

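For context on the renamed pieces, here is a minimal, self-contained sketch of the `torch.func.functional_call` + `torch.func.vjp` pattern this code relies on: computing per-parameter gradients from a cotangent (`grad_outputs`) at given inputs (`args`). This is not the repository's code; the `nn.Linear` module, shapes, and variable names are illustrative assumptions.

```python
import torch
from torch import nn

# Hypothetical module and inputs, standing in for the wrapped self.module and args.
module = nn.Linear(3, 2)
trainable_params = dict(module.named_parameters())
args = (torch.randn(1, 3),)       # a "batch of 1" input, mirroring the unsqueeze(0) trick above
grad_outputs = torch.ones(1, 2)   # cotangent with the same shape as the module's output

def functional_model_call(params: dict[str, torch.Tensor]) -> torch.Tensor:
    # Re-run the module with the supplied parameters instead of its registered state.
    return torch.func.functional_call(module, params, args)

# torch.func.vjp returns (output, vjp_fn); applying vjp_fn to the cotangent yields
# gradients w.r.t. the primals, here a single dict mapping parameter names to gradients.
_, vjp_fn = torch.func.vjp(functional_model_call, trainable_params)
param_grads = vjp_fn(grad_outputs)[0]
print({name: g.shape for name, g in param_grads.items()})
# e.g. {'weight': torch.Size([2, 3]), 'bias': torch.Size([2])}
```

Because the primal passed to `torch.func.vjp` is the parameter dict, the returned gradients keep the same `name -> Tensor` structure, which is what lets the VJP wrapper return `dict[str, Tensor]` directly.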