
Commit f5a8996

fix(autogram): Remove reference cycles (#424)
* Refactor ModuleHookManager
* Refactor autograd Functions
* Add finalizer in ModuleHookManager to unhook
* Remove manual garbage collection
1 parent d3a9e8b commit f5a8996
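To make the problem concrete: a forward hook written as a closure over its manager, while the manager also stores the hook, forms a reference cycle that reference counting cannot reclaim. The sketch below is a hypothetical toy (ToyHookManager is illustrative, not torchjd code) showing, under CPython, why such an object survives until the cyclic garbage collector runs:

import gc
import weakref


class ToyHookManager:
    """Illustrative stand-in for a manager whose closure-based hook captures it."""

    def __init__(self):
        self.handles = []

    def hook_module(self):
        def hook(*args):
            return self.handles  # the closure keeps the manager alive...

        self.handles.append(hook)  # ...and the manager keeps the hook alive: a cycle


manager = ToyHookManager()
manager.hook_module()
probe = weakref.ref(manager)
del manager
print(probe() is None)  # False: reference counting alone cannot reclaim the cycle
gc.collect()
print(probe() is None)  # True: only the cyclic garbage collector frees it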

4 files changed, +145 −92 lines changed

src/torchjd/autogram/_engine.py

Lines changed: 2 additions & 2 deletions
@@ -170,13 +170,13 @@ def compute_gramian(self, output: Tensor) -> Tensor:
 
         reshaped_output = output.reshape([-1])
 
-        self._module_hook_manager.gramian_accumulation_phase = True
+        self._module_hook_manager.gramian_accumulation_phase.value = True
 
         try:
             square_gramian = self._compute_square_gramian(reshaped_output)
         finally:
             # Reset everything that has a state, even if the previous call raised an exception
-            self._module_hook_manager.gramian_accumulation_phase = False
+            self._module_hook_manager.gramian_accumulation_phase.value = False
             self._gramian_accumulator.reset()
             self._target_edges.reset()
 
Lines changed: 137 additions & 67 deletions
@@ -1,3 +1,5 @@
+import weakref
+from collections.abc import Callable
 from typing import cast
 
 import torch
@@ -19,6 +21,16 @@
 # still support older versions of PyTorch where pytree is protected).
 
 
+class BoolRef:
+    """Class wrapping a boolean value, acting as a reference to this boolean value."""
+
+    def __init__(self, value: bool):
+        self.value = value
+
+    def __bool__(self) -> bool:
+        return self.value
+
+
 class ModuleHookManager:
     """
     Class responsible for handling hooks and Nodes that computes the Gramian reverse accumulation.
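The point of BoolRef is that the engine and every hook hold the same mutable object, so flipping gramian_accumulation_phase.value (as the _engine.py hunk above does) is immediately visible to all holders, whereas a plain bool rebound on the manager would only change the manager's own attribute. A minimal sketch, reusing the BoolRef class from this diff (the holders list is purely illustrative):

class BoolRef:
    """Copy of the class added above: a mutable wrapper around a boolean."""

    def __init__(self, value: bool):
        self.value = value

    def __bool__(self) -> bool:
        return self.value


flag = BoolRef(False)
holders = [flag, flag]  # e.g. the engine and a forward hook sharing one flag

flag.value = True  # toggled in one place (as compute_gramian does)...
print(all(bool(h) for h in holders))  # True: ...and observed by every holder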
@@ -35,9 +47,19 @@ def __init__(
     ):
         self._target_edges = target_edges
         self._gramian_accumulator = gramian_accumulator
-        self.gramian_accumulation_phase = False
+        self.gramian_accumulation_phase = BoolRef(False)
         self._handles: list[TorchRemovableHandle] = []
 
+        # When the ModuleHookManager is not referenced anymore, there is no reason to keep the hooks
+        # alive. In fact, keeping the hooks alive would also keep the target edges alive, which
+        # would keep the graph or part of the graph alive. Since the graph contains nodes that store
+        # the module in their context, which themselves reference their hooks, the hooks will be
+        # caught in a reference cycle and will not be freed by the garbage collector. It is thus
+        # important to remove the hooks whenever we're sure we won't need them anymore.
+        # We could have used a __del__ method here, with the same effects, but weakref.finalize
+        # seems to be a better practice (and it only works if the function to call is static).
+        self._finalizer = weakref.finalize(self, ModuleHookManager.remove_hooks, self._handles)
+
     def hook_module(self, module: nn.Module) -> None:
         """
         Add a module hook used to insert Jacobian accumulation nodes into the backward graph.
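The comment above is the heart of the fix: the finalizer's callback must not reference the manager itself (hence a static method that receives only the handles list), otherwise weakref.finalize would hold the manager alive and defeat the purpose. A hedged, self-contained sketch of the same pattern, with a hypothetical FakeHandle standing in for torch's removable handles:

import weakref


class FakeHandle:
    """Hypothetical stand-in for a torch RemovableHandle."""

    def __init__(self):
        self.removed = False

    def remove(self) -> None:
        self.removed = True


def remove_hooks(handles: list) -> None:
    for handle in handles:
        handle.remove()


class Manager:
    def __init__(self):
        self.handles = [FakeHandle()]
        # The callback and its arguments hold no reference back to `self`.
        self._finalizer = weakref.finalize(self, remove_hooks, self.handles)


handles_view = Manager().handles  # the temporary Manager is dropped right away
print(handles_view[0].removed)    # True under CPython: the finalizer already ran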
@@ -46,85 +68,133 @@ def hook_module(self, module: nn.Module) -> None:
         enabling Gramian computation.
         """
 
-        def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree:
-            if self.gramian_accumulation_phase:
-                return output
-
-            flat_outputs, tree_spec = tree_flatten(output)
+        hook = Hook(self.gramian_accumulation_phase, self._target_edges, self._gramian_accumulator)
+        self._handles.append(module.register_forward_hook(hook))
 
-            if not any(isinstance(t, Tensor) for t in flat_outputs):
-                # This can happen only if a module returns no Tensor, for instance some niche usage
-                # such as a module that prints something.
-                return output
-
-            requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
-            self._gramian_accumulator.track_parameter_paths(requires_grad_params)
+    @staticmethod
+    def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
+        """
+        Remove all registered hooks. This method is deliberately static so that it can be called by
+        weakref.finalize.
+        """
 
-            # We only care about running the JacobianAccumulator node, so we need one of its child
-            # edges (the edges of the original ouputs of the model) as target. For memory
-            # efficiency, we select the smallest one (that requires grad).
-            inf = float("inf")
-            preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
-            index = cast(int, preference.argmin().item())
-            self._target_edges.register(get_gradient_edge(flat_outputs[index]))
+        for handle in handles:
+            handle.remove()
 
-            return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs)
 
-        handle = module.register_forward_hook(module_hook)
-        self._handles.append(handle)
+class AccumulateJacobian(torch.autograd.Function):
 
-    def _apply_jacobian_accumulator(
-        self,
-        module: nn.Module,
-        args: PyTree,
+    @staticmethod
+    def forward(
+        ctx,
         tree_spec: TreeSpec,
-        flat_outputs: list[Tensor],
-    ) -> PyTree:
-        vjp = torch.vmap(get_functional_vjp(module))
-
-        class AccumulateJacobian(torch.autograd.Function):
-
-            @staticmethod
-            def forward(*flat_grad_outputs: Tensor) -> None:
-                grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec)
-                jacobians = vjp(grad_outputs, args)
-                self._gramian_accumulator.accumulate_path_jacobians(
-                    {
-                        module.get_parameter(param_name): jacobian
-                        for param_name, jacobian in jacobians.items()
-                    }
-                )
+        vjp: Callable[[PyTree, PyTree], dict[str, Tensor]],
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *flat_grad_outputs: Tensor,
+    ) -> None:
+        grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec)
+        jacobians = vjp(grad_outputs, args)
+        gramian_accumulator.accumulate_path_jacobians(
+            {
+                module.get_parameter(param_name): jacobian
+                for param_name, jacobian in jacobians.items()
+            }
+        )
+
+
+class JacobianAccumulator(torch.autograd.Function):
+    """
+    Autograd function that accumulates Jacobian Gramians during the first backward pass.
 
-            @staticmethod
-            def setup_context(*_):
-                pass
+    Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
+    of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
+    toggle mechanism to activate only during the Gramian accumulation phase.
+    """
 
-        class JacobianAccumulator(torch.autograd.Function):
-            """
-            Autograd function that accumulates Jacobian Gramians during the first backward pass.
+    generate_vmap_rule = True
 
-            Acts as identity on forward pass. During the autogram algorithm, computes the Jacobian
-            of outputs w.r.t. module parameters and feeds it to the gramian accumulator. Uses a
-            toggle mechanism to activate only during the Gramian accumulation phase.
-            """
+    @staticmethod
+    def forward(
+        ctx,
+        gramian_accumulation_phase: BoolRef,
+        tree_spec: TreeSpec,
+        vjp: Callable[[PyTree, PyTree], dict[str, Tensor]],
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *xs: Tensor,
+    ) -> tuple[Tensor, ...]:
+        ctx.gramian_accumulation_phase = gramian_accumulation_phase
+        ctx.tree_spec = tree_spec
+        ctx.vjp = vjp
+        ctx.args = args
+        ctx.gramian_accumulator = gramian_accumulator
+        ctx.module = module
+        return tuple([x.detach() for x in xs])
+
+    @staticmethod
+    def backward(ctx, *flat_grad_outputs: Tensor):
+        if not ctx.gramian_accumulation_phase:
+            return None, None, None, None, None, None, *flat_grad_outputs
+
+        AccumulateJacobian.apply(
+            ctx.tree_spec,
+            ctx.vjp,
+            ctx.args,
+            ctx.gramian_accumulator,
+            ctx.module,
+            *flat_grad_outputs,
+        )
+
+        return None, None, None, None, None, None, *flat_grad_outputs
+
+
+class Hook:
+    def __init__(
+        self,
+        gramian_accumulation_phase: BoolRef,
+        target_edges: EdgeRegistry,
+        gramian_accumulator: GramianAccumulator,
+    ):
+        self.gramian_accumulation_phase = gramian_accumulation_phase
+        self.target_edges = target_edges
+        self.gramian_accumulator = gramian_accumulator
 
-            generate_vmap_rule = True
+    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
+        if self.gramian_accumulation_phase:
+            return output
 
-            @staticmethod
-            def forward(*xs: Tensor) -> tuple[Tensor, ...]:
-                return tuple([x.detach() for x in xs])
+        flat_outputs, tree_spec = tree_flatten(output)
 
-            @staticmethod
-            def setup_context(*_):
-                pass
+        if not any(isinstance(t, Tensor) for t in flat_outputs):
+            # This can happen only if a module returns no Tensor, for instance some niche usage
+            # such as a module that prints something.
+            return output
 
-            @staticmethod
-            def backward(ctx, *flat_grad_outputs: Tensor):
-                if not self.gramian_accumulation_phase:
-                    return flat_grad_outputs
+        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
+        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
 
-                AccumulateJacobian.apply(*flat_grad_outputs)
+        # We only care about running the JacobianAccumulator node, so we need one of its child
+        # edges (the edges of the original ouputs of the model) as target. For memory
+        # efficiency, we select the smallest one (that requires grad).
+        inf = float("inf")
+        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
+        index = cast(int, preference.argmin().item())
+        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
 
-                return flat_grad_outputs
+        vjp = torch.vmap(get_functional_vjp(module))
 
-        return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec)
+        return tree_unflatten(
+            JacobianAccumulator.apply(
+                self.gramian_accumulation_phase,
+                tree_spec,
+                vjp,
+                args,
+                self.gramian_accumulator,
+                module,
+                *flat_outputs,
+            ),
+            tree_spec,
+        )
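The refactor above replaces closures that captured self and module with top-level autograd Functions whose state lives on ctx and on a plain Hook object, which is what breaks the reference cycles. The general ctx-based pattern (identity forward that detaches its inputs, backward that returns one None per non-tensor input and passes gradients through, optionally triggering a side effect) can be sketched in isolation; PassThrough and side_effect below are illustrative names, not torchjd APIs:

import torch


class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, flag, side_effect, *xs):
        ctx.flag = flag               # non-tensor state is kept on the context, not in a closure
        ctx.side_effect = side_effect
        return tuple(x.detach() for x in xs)  # identity on the forward pass

    @staticmethod
    def backward(ctx, *grad_outputs):
        if ctx.flag:
            ctx.side_effect(grad_outputs)  # e.g. accumulate something from the gradients
        # one None per non-tensor input (flag, side_effect), then pass the gradients through
        return None, None, *grad_outputs


x = torch.ones(3, requires_grad=True)
(y,) = PassThrough.apply(True, lambda grads: print("saw", len(grads), "gradient(s)"), x)
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])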

tests/unit/autogram/test_engine.py

Lines changed: 6 additions & 6 deletions
@@ -96,12 +96,12 @@
     (FreeParam, 32),
     (NoFreeParam, 32),
     param(Randomness, 32, marks=mark.xfail),
-    param(Cifar10Model, 16, marks=[mark.slow, mark.garbage_collect]),
-    param(AlexNet, 2, marks=[mark.slow, mark.garbage_collect]),
-    param(InstanceNormResNet18, 4, marks=[mark.slow, mark.garbage_collect]),
-    param(GroupNormMobileNetV3Small, 3, marks=[mark.slow, mark.garbage_collect]),
-    param(SqueezeNet, 8, marks=[mark.slow, mark.garbage_collect]),
-    param(InstanceNormMobileNetV2, 2, marks=[mark.slow, mark.garbage_collect]),
+    param(Cifar10Model, 16, marks=[mark.slow]),
+    param(AlexNet, 2, marks=[mark.slow]),
+    param(InstanceNormResNet18, 4, marks=[mark.slow]),
+    param(GroupNormMobileNetV3Small, 3, marks=[mark.slow]),
+    param(SqueezeNet, 8, marks=[mark.slow]),
+    param(InstanceNormMobileNetV2, 2, marks=[mark.slow]),
 ]
 
 
tests/unit/conftest.py

Lines changed: 0 additions & 17 deletions
@@ -1,4 +1,3 @@
-import gc
 import os
 import random as rand
 
@@ -32,28 +31,12 @@ def fix_randomness() -> None:
     torch.use_deterministic_algorithms(True)
 
 
-@fixture(autouse=True)
-def garbage_collect_if_marked(request):
-    """
-    Since garbage collection takes some time, we only do it when needed (when the test or the
-    parametrization of the test is marked with mark.garbage_collect). This is currently useful for
-    freeing CUDA memory after a lot has been allocated.
-    """
-
-    yield
-    if request.node.get_closest_marker("garbage_collect"):
-        if DEVICE.type == "cuda":
-            torch.cuda.empty_cache()
-        gc.collect()
-
-
 def pytest_addoption(parser):
     parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
 
 
 def pytest_configure(config):
     config.addinivalue_line("markers", "slow: mark test as slow to run")
-    config.addinivalue_line("markers", "garbage_collect: do garbage collection after test")
 
 
 def pytest_collection_modifyitems(config, items):
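With the cycles gone, objects are reclaimed by reference counting as soon as the tests drop them, so the per-test gc.collect()/empty_cache() fixture and its marker become unnecessary, which is what the test changes above reflect. If one wanted to assert this property directly, a weakref probe is enough; a hedged sketch, where SomeCycleFreeObject is a placeholder rather than a torchjd class:

import weakref


class SomeCycleFreeObject:
    """Placeholder for any object that no longer participates in reference cycles."""


obj = SomeCycleFreeObject()
probe = weakref.ref(obj)
del obj
assert probe() is None  # freed by reference counting alone, no gc.collect() required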

0 commit comments
