
Commit 832b242

Merge branch 'main' into autogram-readme
2 parents: 3a442b2 + 19d375d

11 files changed: +48 −90 lines changed

docs/source/examples/iwmtl.rst

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ The following example shows how to do that.
  optimizer = SGD(params, lr=0.1)
  mse = MSELoss(reduction="none")
  weighting = Flattening(UPGradWeighting())
- engine = Engine(shared_module.modules(), batch_dim=0)
+ engine = Engine(shared_module, batch_dim=0)

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task

docs/source/examples/iwrm.rst

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
  params = model.parameters()
  optimizer = SGD(params, lr=0.1)
  weighting = UPGradWeighting()
- engine = Engine(model.modules(), batch_dim=0)
+ engine = Engine(model, batch_dim=0)

  for x, y in zip(X, Y):
      y_hat = model(x).squeeze(dim=1)  # shape: [16]
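For readers following this change, here is a minimal sketch of how the updated call fits into the IWRM loop shown above. The import paths, the engine.compute_gramian method, and the weighting/backward flow are assumptions based on the surrounding torchjd documentation, not taken from this diff.

    # Hedged sketch of an IWRM loop around the updated Engine call.
    # Import paths and the compute_gramian/weighting/backward flow are assumed.
    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import UPGradWeighting  # assumed import path
    from torchjd.autogram import Engine              # assumed import path

    X = torch.randn(8, 16, 10)  # 8 batches of 16 inputs of dimension 10
    Y = torch.randn(8, 16)      # one scalar target per instance

    model = Sequential(Linear(10, 32), ReLU(), Linear(32, 1))
    criterion = MSELoss(reduction="none")  # keep one loss per instance

    optimizer = SGD(model.parameters(), lr=0.1)
    weighting = UPGradWeighting()
    engine = Engine(model, batch_dim=0)  # new API: pass the module itself

    for x, y in zip(X, Y):
        y_hat = model(x).squeeze(dim=1)           # shape: [16]
        losses = criterion(y_hat, y)              # shape: [16], one loss per instance
        optimizer.zero_grad()
        gramian = engine.compute_gramian(losses)  # assumed Engine method
        weights = weighting(gramian)              # weights for the per-instance losses
        losses.backward(weights)                  # weighted backward pass
        optimizer.step()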

docs/source/examples/partial_jd.rst

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.

  # Create the autogram engine that will compute the Gramian of the
  # Jacobian with respect to the two last Linear layers' parameters.
- engine = Engine(model[2:].modules(), batch_dim=0)
+ engine = Engine(model[2:], batch_dim=0)

  params = model.parameters()
  optimizer = SGD(params, lr=0.1)
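A small sketch of why passing the slice works under the new API: slicing an nn.Sequential returns another nn.Sequential, so model[2:] is a single module that can be handed to Engine directly. The layer sizes below are illustrative, not taken from the example being edited, and the torchjd import path is assumed.

    # Hedged sketch: restricting the Gramian to the last layers of a Sequential.
    from torch.nn import Linear, ReLU, Sequential

    from torchjd.autogram import Engine  # assumed import path

    model = Sequential(Linear(10, 32), ReLU(), Linear(32, 32), ReLU(), Linear(32, 1))

    # Slicing an nn.Sequential returns another nn.Sequential, so model[2:] is a
    # single module containing only the last layers; under the new API it is
    # passed to Engine directly instead of via model[2:].modules().
    engine = Engine(model[2:], batch_dim=0)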

src/torchjd/autogram/_engine.py

Lines changed: 14 additions & 14 deletions
@@ -1,4 +1,3 @@
- from collections.abc import Iterable
  from typing import cast

  import torch

@@ -58,8 +57,9 @@ class Engine:
  backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using
  :func:`torchjd.autojac.backward`.

- :param modules: A collection of modules whose direct (non-recursive) parameters will contribute
-     to the Gramian of the Jacobian.
+ :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
+     Several modules can be provided, but it's important that none of them is a child module of
+     another of them.
  :param batch_dim: If the modules work with batches and process each batch element independently,
      then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
      memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates

@@ -91,7 +91,7 @@ class Engine:
  weighting = UPGradWeighting()

  # Create the engine before the backward pass, and only once.
- engine = Engine(model.modules(), batch_dim=0)
+ engine = Engine(model, batch_dim=0)

  for input, target in zip(inputs, targets):
      output = model(input).squeeze(dim=1)  # shape: [16]

@@ -173,7 +173,7 @@ class Engine:

  def __init__(
      self,
-     modules: Iterable[nn.Module],
+     *modules: nn.Module,
      batch_dim: int | None,
  ):
      self._gramian_accumulator = GramianAccumulator()

@@ -183,16 +183,16 @@ def __init__(
          self._target_edges, self._gramian_accumulator, batch_dim is not None
      )

-     self._hook_modules(modules)
+     for module in modules:
+         self._hook_module_recursively(module)

- def _hook_modules(self, modules: Iterable[nn.Module]) -> None:
-     _modules = set(modules)
-
-     # Add module forward hooks to compute jacobians
-     for module in _modules:
-         if any(p.requires_grad for p in module.parameters(recurse=False)):
-             self._check_module_is_compatible(module)
-             self._module_hook_manager.hook_module(module)
+ def _hook_module_recursively(self, module: nn.Module) -> None:
+     if any(p.requires_grad for p in module.parameters(recurse=False)):
+         self._check_module_is_compatible(module)
+         self._module_hook_manager.hook_module(module)
+     else:
+         for child in module.children():
+             self._hook_module_recursively(child)

  def _check_module_is_compatible(self, module: nn.Module) -> None:
      if self._batch_dim is not None:
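The new variadic signature also allows passing several sibling modules at once, subject to the documented constraint that none of them is a child of another. A hedged sketch with hypothetical module names (shared, head1, head2) and an assumed import path:

    # Hedged sketch of the new variadic signature, with hypothetical module names.
    from torch.nn import Linear, ReLU, Sequential

    from torchjd.autogram import Engine  # assumed import path

    shared = Sequential(Linear(10, 32), ReLU())
    head1 = Linear(32, 1)
    head2 = Linear(32, 1)

    # Several modules may be passed, as long as none is a child of another.
    # Passing e.g. both shared and shared[0] would violate the documented
    # constraint, since hooking is now recursive.
    engine = Engine(shared, head1, head2, batch_dim=0)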

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 2 additions & 3 deletions
@@ -9,7 +9,6 @@

  from ._edge_registry import EdgeRegistry
  from ._gramian_accumulator import GramianAccumulator
- from ._module_utils import get_used_params
  from ._vjp import VJP, AutogradVJP, FunctionalVJP

  # Note about import from protected _pytree module:

@@ -126,8 +125,8 @@ def __call__(
  # require grad
  return outputs

- rg_params, _ = get_used_params(module)
- self.gramian_accumulator.track_parameter_paths(rg_params.values())
+ rg_params = [p for p in module.parameters(recurse=True) if p.requires_grad]
+ self.gramian_accumulator.track_parameter_paths(rg_params)

  # We only care about running the JacobianAccumulator node, so we need one of its child
  # edges (the edges of the original outputs of the model) as target. For memory

src/torchjd/autogram/_module_utils.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/torchjd/autogram/_vjp.py

Lines changed: 9 additions & 3 deletions
@@ -6,8 +6,6 @@
  from torch.nn import Parameter
  from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten

- from torchjd.autogram._module_utils import get_used_params
-
  # Note about import from protected _pytree module:
  # PyTorch maintainers plan to make pytree public (see
  # https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).

@@ -39,7 +37,15 @@ class ModuleVJP(VJP, ABC):

  def __init__(self, module: nn.Module):
      self.module = module
-     self.rg_params, self.frozen_params = get_used_params(module)
+
+     self.rg_params = dict[str, Parameter]()
+     self.frozen_params = dict[str, Parameter]()
+
+     for name, param in module.named_parameters(recurse=True):
+         if param.requires_grad:
+             self.rg_params[name] = param
+         else:
+             self.frozen_params[name] = param


  class FunctionalVJP(ModuleVJP):
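The new __init__ replaces the deleted get_used_params helper with an inline split of the module's parameters by requires_grad. A standalone, hedged sketch of that split in plain PyTorch (the module and the frozen layer are illustrative only):

    # Hedged, standalone sketch of the requires_grad split performed above.
    from torch.nn import Linear, Parameter, ReLU, Sequential

    module = Sequential(Linear(10, 32), ReLU(), Linear(32, 1))
    module[2].requires_grad_(False)  # freeze the last layer, for illustration

    rg_params: dict[str, Parameter] = {}
    frozen_params: dict[str, Parameter] = {}
    for name, param in module.named_parameters(recurse=True):
        if param.requires_grad:
            rg_params[name] = param
        else:
            frozen_params[name] = param

    # rg_params now holds "0.weight" and "0.bias";
    # frozen_params holds "2.weight" and "2.bias".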

tests/doc/test_autogram.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def test_engine():
  weighting = UPGradWeighting()

  # Create the engine before the backward pass, and only once.
- engine = Engine(model.modules(), batch_dim=0)
+ engine = Engine(model, batch_dim=0)

  for input, target in zip(inputs, targets):
      output = model(input).squeeze(dim=1)  # shape: [16]

tests/doc/test_rst.py

Lines changed: 3 additions & 3 deletions
@@ -94,7 +94,7 @@ def test_iwmtl():
  optimizer = SGD(params, lr=0.1)
  mse = MSELoss(reduction="none")
  weighting = Flattening(UPGradWeighting())
- engine = Engine(shared_module.modules(), batch_dim=0)
+ engine = Engine(shared_module, batch_dim=0)

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task

@@ -184,7 +184,7 @@ def test_autogram():
  params = model.parameters()
  optimizer = SGD(params, lr=0.1)
  weighting = UPGradWeighting()
- engine = Engine(model.modules(), batch_dim=0)
+ engine = Engine(model, batch_dim=0)

  for x, y in zip(X, Y):
      y_hat = model(x).squeeze(dim=1)  # shape: [16]

@@ -374,7 +374,7 @@ def test_partial_jd():

  # Create the autogram engine that will compute the Gramian of the
  # Jacobian with respect to the two last Linear layers' parameters.
- engine = Engine(model[2:].modules(), batch_dim=0)
+ engine = Engine(model[2:], batch_dim=0)

  params = model.parameters()
  optimizer = SGD(params, lr=0.1)

tests/speed/autogram/grad_vs_jac_vs_gram.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def post_fn():
  print(autojac_times)
  print()

- engine = Engine(model.modules(), batch_dim=0)
+ engine = Engine(model, batch_dim=0)
  autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs))
  print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}")
  print(autogram_times)
