
Commit 19d375d

refactor(autogram): Make engine hook recursively (#451)
* Make the Engine hook modules recursively. If a module has a direct parameter, hook that module and do not hook its children; if not, try to hook its child modules.
* This change means that only the parentmost module with direct requires-grad parameters gets hooked. The parameters it uses are thus simply module.parameters(recurse=True). This holds even in the special cases where the parent uses its children's parameters, so we no longer need a special case for MHA.
* Remove _module_utils.py: it is now trivial to know with respect to which parameters to differentiate.
* Update all usages to create the Engine with Engine(model) instead of Engine(model.modules()). For partial JD, users have to be more careful: they may need to specify several modules, and these modules must be "disjoint" (i.e. no specified module should be a child of another specified module).
* This mostly makes a difference for FreeParam. Before, we had 2 hooks (one for the parent, parameterized with the parent's own param, a.k.a. the free param, and one for the child module, parameterized with the child's params). Now we simply have 1 hook for the parent, parameterized with all the parameters (i.e. parent.parameters(recurse=True)). This is probably faster (we no longer need 2 extra forwards and 2 extra backwards for the child, just 1 of each), but may use a bit more memory (we have to store the Jacobian with respect to the child's params and with respect to the parent's free param at the same time). This case is quite niche, though, and I still see it as an improvement.
* Change Engine to take *modules: nn.Module instead of Iterable[nn.Module] (more convenient for the new usage, since we only specify one model 99% of the time). Update the docstring accordingly.
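
A minimal sketch of the new calling convention (the toy model is illustrative, and the torchjd.autogram.Engine import path is assumed from the documentation examples updated in this commit):

from torch import nn

from torchjd.autogram import Engine  # assumed import path, as in the updated docs

# Illustrative toy model, not taken from the repository.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))

# Old usage: Engine(model.modules(), batch_dim=0)
# New usage: pass the module(s) themselves; the engine hooks them recursively.
engine = Engine(model, batch_dim=0)
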
1 parent 6d0b3a8 commit 19d375d

File tree

11 files changed, +48 -90 lines changed

docs/source/examples/iwmtl.rst

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ The following example shows how to do that.
 optimizer = SGD(params, lr=0.1)
 mse = MSELoss(reduction="none")
 weighting = Flattening(UPGradWeighting())
-engine = Engine(shared_module.modules(), batch_dim=0)
+engine = Engine(shared_module, batch_dim=0)

 inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
 task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task

docs/source/examples/iwrm.rst

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
 params = model.parameters()
 optimizer = SGD(params, lr=0.1)
 weighting = UPGradWeighting()
-engine = Engine(model.modules(), batch_dim=0)
+engine = Engine(model, batch_dim=0)

 for x, y in zip(X, Y):
     y_hat = model(x).squeeze(dim=1) # shape: [16]

docs/source/examples/partial_jd.rst

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.

 # Create the autogram engine that will compute the Gramian of the
 # Jacobian with respect to the two last Linear layers' parameters.
-engine = Engine(model[2:].modules(), batch_dim=0)
+engine = Engine(model[2:], batch_dim=0)

 params = model.parameters()
 optimizer = SGD(params, lr=0.1)
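
Under the new recursive hooking, partial JD still accepts several modules, but they must be disjoint, as the commit message states. A hypothetical sketch (layer sizes and indices are illustrative, not taken from the example file; the torchjd.autogram.Engine import path is assumed from the docs):

from torch import nn

from torchjd.autogram import Engine  # assumed import path

model = nn.Sequential(
    nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 1)
)

# Fine: a slice of the Sequential covering the last layers.
engine = Engine(model[2:], batch_dim=0)

# Also fine (alternative): several sibling modules, none of which contains another.
# engine = Engine(model[2], model[4], batch_dim=0)

# Not allowed by the new contract: model[2] is a child of model, so the two
# specified modules are not disjoint.
# engine = Engine(model, model[2], batch_dim=0)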

src/torchjd/autogram/_engine.py

Lines changed: 14 additions & 14 deletions
@@ -1,4 +1,3 @@
-from collections.abc import Iterable
 from typing import cast

 import torch

@@ -58,8 +57,9 @@ class Engine:
     backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using
     :func:`torchjd.autojac.backward`.

-    :param modules: A collection of modules whose direct (non-recursive) parameters will contribute
-        to the Gramian of the Jacobian.
+    :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
+        Several modules can be provided, but it's important that none of them is a child module of
+        another of them.
     :param batch_dim: If the modules work with batches and process each batch element independently,
         then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
         memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates

@@ -91,7 +91,7 @@ class Engine:
         weighting = UPGradWeighting()

         # Create the engine before the backward pass, and only once.
-        engine = Engine(model.modules(), batch_dim=0)
+        engine = Engine(model, batch_dim=0)

         for input, target in zip(inputs, targets):
             output = model(input).squeeze(dim=1) # shape: [16]

@@ -173,7 +173,7 @@ class Engine:

     def __init__(
         self,
-        modules: Iterable[nn.Module],
+        *modules: nn.Module,
         batch_dim: int | None,
     ):
         self._gramian_accumulator = GramianAccumulator()

@@ -183,16 +183,16 @@ def __init__(
            self._target_edges, self._gramian_accumulator, batch_dim is not None
        )

-       self._hook_modules(modules)
+       for module in modules:
+           self._hook_module_recursively(module)

-   def _hook_modules(self, modules: Iterable[nn.Module]) -> None:
-       _modules = set(modules)
-
-       # Add module forward hooks to compute jacobians
-       for module in _modules:
-           if any(p.requires_grad for p in module.parameters(recurse=False)):
-               self._check_module_is_compatible(module)
-               self._module_hook_manager.hook_module(module)
+   def _hook_module_recursively(self, module: nn.Module) -> None:
+       if any(p.requires_grad for p in module.parameters(recurse=False)):
+           self._check_module_is_compatible(module)
+           self._module_hook_manager.hook_module(module)
+       else:
+           for child in module.children():
+               self._hook_module_recursively(child)

    def _check_module_is_compatible(self, module: nn.Module) -> None:
        if self._batch_dim is not None:
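
The hook-selection rule can be illustrated with a standalone re-implementation that collects modules instead of hooking them (illustration only; the helper name is made up and nothing below comes from the library):

from torch import nn


def modules_that_get_hooked(module: nn.Module) -> list[nn.Module]:
    # Stop at the parentmost module that owns a direct requires-grad parameter;
    # otherwise descend into the children, mirroring _hook_module_recursively.
    if any(p.requires_grad for p in module.parameters(recurse=False)):
        return [module]
    hooked: list[nn.Module] = []
    for child in module.children():
        hooked.extend(modules_that_get_hooked(child))
    return hooked


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
# nn.Sequential owns no direct parameters, so its two Linear children get hooked,
# each covering its own parameters via parameters(recurse=True).
print([type(m).__name__ for m in modules_that_get_hooked(model)])  # ['Linear', 'Linear']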

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 2 additions & 3 deletions
@@ -9,7 +9,6 @@

 from ._edge_registry import EdgeRegistry
 from ._gramian_accumulator import GramianAccumulator
-from ._module_utils import get_used_params
 from ._vjp import VJP, AutogradVJP, FunctionalVJP

 # Note about import from protected _pytree module:

@@ -126,8 +125,8 @@ def __call__(
            # require grad
            return outputs

-       rg_params, _ = get_used_params(module)
-       self.gramian_accumulator.track_parameter_paths(rg_params.values())
+       rg_params = [p for p in module.parameters(recurse=True) if p.requires_grad]
+       self.gramian_accumulator.track_parameter_paths(rg_params)

        # We only care about running the JacobianAccumulator node, so we need one of its child
        # edges (the edges of the original outputs of the model) as target. For memory

src/torchjd/autogram/_module_utils.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/torchjd/autogram/_vjp.py

Lines changed: 9 additions & 3 deletions
@@ -6,8 +6,6 @@
 from torch.nn import Parameter
 from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten

-from torchjd.autogram._module_utils import get_used_params
-
 # Note about import from protected _pytree module:
 # PyTorch maintainers plan to make pytree public (see
 # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).

@@ -39,7 +37,15 @@ class ModuleVJP(VJP, ABC):

     def __init__(self, module: nn.Module):
         self.module = module
-        self.rg_params, self.frozen_params = get_used_params(module)
+
+        self.rg_params = dict[str, Parameter]()
+        self.frozen_params = dict[str, Parameter]()
+
+        for name, param in module.named_parameters(recurse=True):
+            if param.requires_grad:
+                self.rg_params[name] = param
+            else:
+                self.frozen_params[name] = param


 class FunctionalVJP(ModuleVJP):
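
A standalone sketch of the same split, showing what the two dicts end up containing for a layer with a frozen parameter (the layer and the freezing are illustrative):

from torch import nn
from torch.nn import Parameter

layer = nn.Linear(3, 2)
layer.bias.requires_grad_(False)  # freeze one parameter

rg_params = dict[str, Parameter]()
frozen_params = dict[str, Parameter]()
for name, param in layer.named_parameters(recurse=True):
    if param.requires_grad:
        rg_params[name] = param
    else:
        frozen_params[name] = param

print(list(rg_params))      # ['weight']
print(list(frozen_params))  # ['bias']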

tests/doc/test_autogram.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def test_engine():
     weighting = UPGradWeighting()

     # Create the engine before the backward pass, and only once.
-    engine = Engine(model.modules(), batch_dim=0)
+    engine = Engine(model, batch_dim=0)

     for input, target in zip(inputs, targets):
         output = model(input).squeeze(dim=1) # shape: [16]

tests/doc/test_rst.py

Lines changed: 3 additions & 3 deletions
@@ -94,7 +94,7 @@ def test_iwmtl():
     optimizer = SGD(params, lr=0.1)
     mse = MSELoss(reduction="none")
     weighting = Flattening(UPGradWeighting())
-    engine = Engine(shared_module.modules(), batch_dim=0)
+    engine = Engine(shared_module, batch_dim=0)

     inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
     task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task

@@ -184,7 +184,7 @@ def test_autogram():
     params = model.parameters()
     optimizer = SGD(params, lr=0.1)
     weighting = UPGradWeighting()
-    engine = Engine(model.modules(), batch_dim=0)
+    engine = Engine(model, batch_dim=0)

     for x, y in zip(X, Y):
         y_hat = model(x).squeeze(dim=1) # shape: [16]

@@ -374,7 +374,7 @@ def test_partial_jd():

     # Create the autogram engine that will compute the Gramian of the
     # Jacobian with respect to the two last Linear layers' parameters.
-    engine = Engine(model[2:].modules(), batch_dim=0)
+    engine = Engine(model[2:], batch_dim=0)

     params = model.parameters()
     optimizer = SGD(params, lr=0.1)

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def post_fn():
 print(autojac_times)
 print()

-engine = Engine(model.modules(), batch_dim=0)
+engine = Engine(model, batch_dim=0)
 autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs))
 print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}")
 print(autogram_times)
