Commit 6aae58c

feat(autogram): Add support for non-Tensor args (#442)
* Fix type hints
* Add in_dims param to FunctionalVJP and compute it in Hook
* Remove xfail on WithModuleWithStringArg
* Add WithModuleWithHybridPyTreeArg
* Relax warning about inputs and outputs of Engine
1 parent bfce3ec commit 6aae58c
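For context, a minimal sketch (plain PyTorch, not torchjd code) of the mechanism behind "support for non-Tensor args": `torch.vmap` can map over a function that takes non-Tensor arguments, as long as `in_dims` marks them with `None` so they are passed through unchanged rather than batched.

```python
# Minimal sketch, assuming nothing beyond public PyTorch APIs.
import torch
from torch import Tensor

def scale(x: Tensor, mode: str) -> Tensor:
    # A per-instance computation that also receives a plain string argument.
    # The branch depends only on the unbatched `mode`, so vmap handles it fine.
    return 2.0 * x if mode == "double" else x

x = torch.randn(8, 3)
# in_dims=(0, None): batch `x` on dim 0, pass `mode` through unchanged.
out = torch.vmap(scale, in_dims=(0, None))(x, "double")  # shape (8, 3)
```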

5 files changed, +85 -19 lines changed


src/torchjd/autogram/_engine.py

Lines changed: 2 additions & 3 deletions
@@ -120,9 +120,8 @@ class Engine:
           this, but for example `BatchNorm
           <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
           computes some average and standard deviation over the elements of the batch).
-        * Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of
-          tensors, or any nesting of those structures), but each of these tensors must be batched on
-          its first dimension. `Transformers
+        * Their inputs and outputs can be anything, but each input tensor and each output tensor
+          must be batched on its first dimension. `Transformers
           <https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ and `RNNs
           <https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are thus not
           supported yet. This is only an implementation issue, so it should be fixed soon (please

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 18 additions & 9 deletions
@@ -101,7 +101,7 @@ def __init__(
         self.gramian_accumulator = gramian_accumulator
         self.has_batch_dim = has_batch_dim
 
-    def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
+    def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree:
         if self.gramian_accumulation_phase:
             return outputs
 
@@ -129,7 +129,14 @@ def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
         index = cast(int, preference.argmin().item())
         self.target_edges.register(get_gradient_edge(rg_outputs[index]))
 
-        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs)
+        vjp: VJP
+        if self.has_batch_dim:
+            rg_outputs_in_dims = (0,) * len(rg_outputs)
+            args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
+            in_dims = (rg_outputs_in_dims, args_in_dims)
+            vjp = FunctionalVJP(module, in_dims)
+        else:
+            vjp = AutogradVJP(module, rg_outputs)
 
         autograd_fn_rg_outputs = JacobianAccumulator.apply(
             self.gramian_accumulation_phase,
@@ -161,15 +168,15 @@ class JacobianAccumulator(torch.autograd.Function):
     def forward(
         gramian_accumulation_phase: BoolRef,
         vjp: VJP,
-        args: PyTree,
+        args: tuple[PyTree, ...],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *rg_tensors: Tensor,
     ) -> tuple[Tensor, ...]:
         return tuple(t.detach() for t in rg_tensors)
 
     # For Python version > 3.10, the type of `inputs` should become
-    # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
+    # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
     @staticmethod
     def setup_context(
         ctx,
@@ -183,7 +190,9 @@ def setup_context(
         ctx.module = inputs[4]
 
     @staticmethod
-    def backward(ctx, *grad_outputs: Tensor):
+    def backward(ctx, *grad_outputs: Tensor) -> tuple:
+        # Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]]
+
         if not ctx.gramian_accumulation_phase:
             return None, None, None, None, None, *grad_outputs
 
@@ -203,7 +212,7 @@ class AccumulateJacobian(torch.autograd.Function):
     @staticmethod
     def forward(
         vjp: VJP,
-        args: PyTree,
+        args: tuple[PyTree, ...],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *grad_outputs: Tensor,
@@ -216,9 +225,9 @@ def forward(
     @staticmethod
     def vmap(
         _,
-        in_dims: PyTree,
+        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]]
         vjp: VJP,
-        args: PyTree,
+        args: tuple[PyTree, ...],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *jac_outputs: Tensor,
@@ -244,5 +253,5 @@ def _make_path_jacobians(
         return path_jacobians
 
     @staticmethod
-    def setup_context(*_):
+    def setup_context(*_) -> None:
         pass
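A standalone illustration (assumptions only, not the torchjd code itself) of why the computed `in_dims` is a pair: the vmapped VJP is called with two positional arguments, `(grad_outputs, args)`, so `in_dims` mirrors that pair. The first entry batches every grad_output on dim 0; the second batches only the Tensor leaves of `args` and marks every other leaf with `None`.

```python
# Illustrative sketch of the in_dims construction; `per_sample` is a made-up
# stand-in for the per-instance VJP callable.
import torch
from torch import Tensor
from torch.utils._pytree import tree_map

def per_sample(grad_outputs: tuple[Tensor, ...], args: tuple) -> Tensor:
    (g,) = grad_outputs
    x, scale, label = args  # Tensor, float, str
    return scale * g * x

grad_outputs = (torch.randn(8, 3),)
args = (torch.randn(8, 3), 0.5, "unused")

rg_outputs_in_dims = (0,) * len(grad_outputs)
args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)  # (0, None, None)
in_dims = (rg_outputs_in_dims, args_in_dims)

out = torch.vmap(per_sample, in_dims=in_dims)(grad_outputs, args)  # shape (8, 3)
```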

src/torchjd/autogram/_vjp.py

Lines changed: 12 additions & 6 deletions
@@ -19,7 +19,9 @@ class VJP(ABC):
     """Represents an abstract VJP function."""
 
     @abstractmethod
-    def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]:
+    def __call__(
+        self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
+    ) -> dict[str, Tensor]:
         """
         Computes and returns the dictionary of parameter names to their gradients for the given
         grad_outputs (cotangents) and at the given inputs.
@@ -52,15 +54,17 @@ class FunctionalVJP(ModuleVJP):
     every module, and it requires to have an extra forward pass to create the vjp function.
     """
 
-    def __init__(self, module: nn.Module):
+    def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]):
         super().__init__(module)
-        self.vmapped_vjp = torch.vmap(self._call_on_one_instance)
+        self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)
 
-    def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]:
+    def __call__(
+        self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
+    ) -> dict[str, Tensor]:
         return self.vmapped_vjp(grad_outputs, args)
 
     def _call_on_one_instance(
-        self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree
+        self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...]
     ) -> dict[str, Tensor]:
         # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
         # "batch" of 1 activation (or grad_output). This is because some layers (e.g.
@@ -103,7 +107,9 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
         self.rg_outputs = rg_outputs
         self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)
 
-    def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]:
+    def __call__(
+        self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...]
+    ) -> dict[str, Tensor]:
         grads = torch.autograd.grad(
             self.rg_outputs,
             self.flat_trainable_params,
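A rough sketch of the pattern this file implements, built only from public `torch.func` APIs (the helper names and shapes below are made up for illustration, not torchjd's actual `FunctionalVJP`): compute a VJP with respect to the module parameters on a single instance, then `torch.vmap` it over the batch with `in_dims` that leave non-Tensor args unbatched.

```python
import torch
from torch import Tensor, nn
from torch.func import functional_call, vjp
from torch.utils._pytree import tree_map

module = nn.Linear(3, 2)
param_names = [name for name, _ in module.named_parameters()]
param_values = tuple(p.detach() for _, p in module.named_parameters())

def call_on_one_instance(grad_output_j: Tensor, args_j: tuple) -> dict[str, Tensor]:
    x_j, note = args_j  # a Tensor of shape (3,) plus a non-Tensor arg

    def forward_from_params(*values: Tensor) -> Tensor:
        params = dict(zip(param_names, values))
        # unsqueeze(0) turns the single instance into a "batch" of 1, as some
        # layers require batched inputs; squeeze(0) undoes it on the output.
        return functional_call(module, params, (x_j.unsqueeze(0),)).squeeze(0)

    _, vjp_fn = vjp(forward_from_params, *param_values)
    grads = vjp_fn(grad_output_j)  # one gradient per parameter
    return dict(zip(param_names, grads))

args = (torch.randn(8, 3), "unused")
grad_outputs = torch.randn(8, 2)

in_dims = (0, tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args))
jacobians = torch.vmap(call_on_one_instance, in_dims=in_dims)(grad_outputs, args)
# jacobians["weight"]: shape (8, 2, 3); jacobians["bias"]: shape (8, 2)
```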

tests/unit/autogram/test_engine.py

Lines changed: 3 additions & 1 deletion
@@ -51,6 +51,7 @@
     WithBuffered,
     WithDropout,
     WithModuleTrackingRunningStats,
+    WithModuleWithHybridPyTreeArg,
     WithModuleWithStringArg,
     WithModuleWithStringOutput,
     WithNoTensorOutput,
@@ -109,6 +110,8 @@
     (Ndim3Output, 32),
     (Ndim4Output, 32),
     (WithDropout, 32),
+    (WithModuleWithStringArg, 32),
+    (WithModuleWithHybridPyTreeArg, 32),
     (WithModuleWithStringOutput, 32),
     (FreeParam, 32),
     (NoFreeParam, 32),
@@ -167,7 +170,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
         Randomness,
         WithModuleTrackingRunningStats,
         param(WithRNN, marks=mark.xfail_if_cuda),
-        WithModuleWithStringArg,
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])

tests/utils/architectures.py

Lines changed: 50 additions & 0 deletions
@@ -797,6 +797,56 @@ def forward(self, input: Tensor) -> Tensor:
         return self.with_string_arg("two", input)
 
 
+class WithModuleWithHybridPyTreeArg(ShapedModule):
+    """
+    Model containing a module that has a PyTree argument containing a mix of tensor and non-tensor
+    leaves.
+    """
+
+    INPUT_SHAPES = (10,)
+    OUTPUT_SHAPES = (3,)
+
+    class WithHybridPyTreeArg(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.m0 = nn.Parameter(torch.randn(3, 3))
+            self.m1 = nn.Parameter(torch.randn(4, 3))
+            self.m2 = nn.Parameter(torch.randn(5, 3))
+            self.m3 = nn.Parameter(torch.randn(6, 3))
+
+        def forward(self, input: PyTree) -> Tensor:
+            t0 = input["one"][0][0]
+            t1 = input["one"][0][1]
+            t2 = input["one"][1]
+            t3 = input["two"]
+
+            c0 = input["one"][0][3]
+            c1 = input["one"][0][4][0]
+            c2 = input["one"][2]
+            c3 = input["three"]
+
+            return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(10, 18)
+        self.with_string_arg = self.WithHybridPyTreeArg()
+
+    def forward(self, input: Tensor) -> Tensor:
+        input = self.linear(input)
+
+        t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18]
+
+        tree = {
+            "zero": "unused",
+            "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"],
+            "two": t3,
+            "three": 0.5,
+        }
+
+        return self.with_string_arg(tree)
+
+
 class WithModuleWithStringOutput(ShapedModule):
     """Model containing a module that has a string output."""
 
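A quick usage sketch of the new test architecture (assuming it can be imported from `tests/utils/architectures.py` in the repo): only the Tensor leaves of the hybrid PyTree built in `forward()` are batched on the first dimension, while the string and float leaves are shared by the whole batch.

```python
# Hypothetical import path, assuming the tests package is on the Python path.
import torch
from tests.utils.architectures import WithModuleWithHybridPyTreeArg

model = WithModuleWithHybridPyTreeArg()
x = torch.randn(32, *WithModuleWithHybridPyTreeArg.INPUT_SHAPES)       # (32, 10)
y = model(x)
assert y.shape == (32, *WithModuleWithHybridPyTreeArg.OUTPUT_SHAPES)   # (32, 3)
```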
