@@ -1,4 +1,3 @@
-from collections.abc import Iterable
 from typing import cast
 
 import torch
@@ -58,8 +57,9 @@ class Engine:
     backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using
     :func:`torchjd.autojac.backward`.
 
-    :param modules: A collection of modules whose direct (non-recursive) parameters will contribute
-        to the Gramian of the Jacobian.
+    :param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
+        Several modules can be provided, but it's important that none of them is a child module of
+        another of them.
     :param batch_dim: If the modules work with batches and process each batch element independently,
         then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
         memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates
@@ -91,7 +91,7 @@ class Engine:
         weighting = UPGradWeighting()
 
         # Create the engine before the backward pass, and only once.
-        engine = Engine(model.modules(), batch_dim=0)
+        engine = Engine(model, batch_dim=0)
 
         for input, target in zip(inputs, targets):
             output = model(input).squeeze(dim=1)  # shape: [16]
@@ -173,7 +173,7 @@ class Engine:
 
     def __init__(
         self,
-        modules: Iterable[nn.Module],
+        *modules: nn.Module,
         batch_dim: int | None,
     ):
         self._gramian_accumulator = GramianAccumulator()
@@ -183,16 +183,16 @@ def __init__(
             self._target_edges, self._gramian_accumulator, batch_dim is not None
         )
 
-        self._hook_modules(modules)
+        for module in modules:
+            self._hook_module_recursively(module)
 
-    def _hook_modules(self, modules: Iterable[nn.Module]) -> None:
-        _modules = set(modules)
-
-        # Add module forward hooks to compute jacobians
-        for module in _modules:
-            if any(p.requires_grad for p in module.parameters(recurse=False)):
-                self._check_module_is_compatible(module)
-                self._module_hook_manager.hook_module(module)
+    def _hook_module_recursively(self, module: nn.Module) -> None:
+        if any(p.requires_grad for p in module.parameters(recurse=False)):
+            self._check_module_is_compatible(module)
+            self._module_hook_manager.hook_module(module)
+        else:
+            for child in module.children():
+                self._hook_module_recursively(child)
 
     def _check_module_is_compatible(self, module: nn.Module) -> None:
         if self._batch_dim is not None:
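
For reference, a minimal usage sketch of the new variadic constructor, based on the docstring example shown in the diff. The `torchjd.autogram` import path, the model architecture, and the `extra_head` module are assumptions for illustration, not taken from this commit:

from torch import nn

from torchjd.autogram import Engine  # assumed import path; adjust to the actual package layout

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

# New API: pass the modules themselves, not an iterable such as model.modules().
engine = Engine(model, batch_dim=0)

# Several modules can be passed, as long as none of them is a child module of another.
extra_head = nn.Linear(5, 1)
engine = Engine(model, extra_head, batch_dim=0)

# Incorrect: model[0] is a child module of model, so the two must not be passed together.
# engine = Engine(model, model[0], batch_dim=0)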
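
To make the new hooking rule concrete, here is a small standalone helper (hypothetical, for illustration only) that mirrors `_hook_module_recursively`: a module that directly owns trainable parameters is hooked as-is; otherwise its children are visited.

from torch import nn

def leaf_hooked_modules(module: nn.Module) -> list[nn.Module]:
    # Mirror of Engine._hook_module_recursively from the diff above: keep a module
    # that has its own trainable parameters; otherwise recurse into its children.
    if any(p.requires_grad for p in module.parameters(recurse=False)):
        return [module]
    hooked: list[nn.Module] = []
    for child in module.children():
        hooked.extend(leaf_hooked_modules(child))
    return hooked

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
print([type(m).__name__ for m in leaf_hooked_modules(model)])  # ['Linear', 'Linear']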