Commit b542e9f

refactor(autogram): Rename internal variables in VJP (#435)
* Rename primals to trainable_params in functional_model_call
* Rename inputs to args
1 parent e44f434 commit b542e9f

File tree

1 file changed (+8, −8 lines)

src/torchjd/autogram/_vjp.py

Lines changed: 8 additions & 8 deletions
@@ -19,7 +19,7 @@ class VJP(ABC):
     """Represents an abstract VJP function."""
 
     @abstractmethod
-    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
+    def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
         """
         Computes and returns the dictionary of parameter names to their gradients for the given
         grad_outputs (cotangents) and at the given inputs.
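
For context, the contract this renamed signature describes can be sketched with plain torch.func. The following is a minimal illustration, not code from this commit; the Linear module, the shapes, and the names `f`, `params`, `args`, and `grad_outputs` are assumptions chosen for the example:

```python
# Minimal sketch of the VJP contract: given cotangents of the module's output
# ("grad_outputs") and the positional inputs ("args"), produce a dict mapping
# parameter names to gradients. Module and shapes are illustrative only.
import torch
from torch.func import functional_call, vjp

module = torch.nn.Linear(3, 2)
params = dict(module.named_parameters())
args = (torch.randn(4, 3),)        # positional inputs to the module
grad_outputs = torch.ones(4, 2)    # cotangents, same shape as the output

def f(p: dict[str, torch.Tensor]) -> torch.Tensor:
    return functional_call(module, p, args)

vjp_fn = vjp(f, params)[1]
grads = vjp_fn(grad_outputs)[0]    # {"weight": Tensor(2, 3), "bias": Tensor(2,)}
```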
@@ -56,25 +56,25 @@ def __init__(self, module: nn.Module):
         super().__init__(module)
         self.vmapped_vjp = torch.vmap(self._call_on_one_instance)
 
-    def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]:
-        return self.vmapped_vjp(grad_outputs, inputs)
+    def __call__(self, grad_outputs: PyTree, args: PyTree) -> dict[str, Tensor]:
+        return self.vmapped_vjp(grad_outputs, args)
 
-    def _call_on_one_instance(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]:
+    def _call_on_one_instance(self, grad_outputs_j: PyTree, args_j: PyTree) -> dict[str, Tensor]:
         # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
         # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
         # nn.Flatten) do not work equivalently if they're provided with a batch or with
         # an element of a batch. We thus always provide them with batches, just of a
         # different size.
-        inputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), inputs_j)
+        args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
         grad_outputs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), grad_outputs_j)
 
-        def functional_model_call(primals: dict[str, Parameter]) -> Tensor:
+        def functional_model_call(trainable_params: dict[str, Parameter]) -> Tensor:
             all_state = {
-                **primals,
+                **trainable_params,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
-            return torch.func.functional_call(self.module, all_state, inputs_j)
+            return torch.func.functional_call(self.module, all_state, args_j)
 
         vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
 
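To see how `vmapped_vjp` composes these pieces, here is a hedged, self-contained sketch of the same pattern outside the class: torch.vmap maps a single-instance VJP over a batch of cotangents, yielding per-sample parameter gradients. The module, batch size, and the names `call_on_one_instance`, `xs`, and `cotangents` are assumptions for illustration, not torchjd API:

```python
# Sketch (assumed setup, not torchjd code): vmap a per-instance VJP, using the
# same unsqueeze(0) "batch of 1" trick and functional_call as the diff above.
import torch
from torch.func import functional_call, vjp

module = torch.nn.Linear(3, 2)
trainable_params = dict(module.named_parameters())

def call_on_one_instance(grad_outputs_j: torch.Tensor, x_j: torch.Tensor):
    x_j = x_j.unsqueeze(0)                        # single instance -> batch of 1
    grad_outputs_j = grad_outputs_j.unsqueeze(0)

    def functional_model_call(params: dict[str, torch.Tensor]) -> torch.Tensor:
        return functional_call(module, params, (x_j,))

    vjp_func = vjp(functional_model_call, trainable_params)[1]
    return vjp_func(grad_outputs_j)[0]            # dict: name -> gradient

xs = torch.randn(5, 3)                            # batch of 5 inputs
cotangents = torch.ones(5, 2)                     # one cotangent per instance
per_sample_grads = torch.vmap(call_on_one_instance)(cotangents, xs)
print(per_sample_grads["weight"].shape)           # torch.Size([5, 2, 3])
```

In the real code, merging `trainable_params` with `named_buffers()` and `self.frozen_params` into `all_state` ensures `functional_call` sees the full module state while only the trainable entries are differentiated, which is what makes the rename from `primals` to `trainable_params` clearer.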
