
Commit d71cbe4

refactor(autogram): Reorder code (#436)

1 parent b542e9f

File tree: 1 file changed (+101, -101 lines)

src/torchjd/autogram/_module_hook_manager.py
Lines changed: 101 additions & 101 deletions
@@ -20,16 +20,6 @@
 # still support older versions of PyTorch where pytree is protected).
 
 
-class BoolRef:
-    """Class wrapping a boolean value, acting as a reference to this boolean value."""
-
-    def __init__(self, value: bool):
-        self.value = value
-
-    def __bool__(self) -> bool:
-        return self.value
-
-
 class ModuleHookManager:
     """
     Class responsible for handling hooks and Nodes that computes the Gramian reverse accumulation.
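Note that BoolRef is only being moved by this commit, not deleted: the same class reappears in the second hunk below. A minimal sketch of what the wrapper is for, with hypothetical names (shared_flag, report) that are not part of the diff:

# Illustrative sketch; BoolRef is copied from the hunk above so the snippet runs standalone.
class BoolRef:
    """Class wrapping a boolean value, acting as a reference to this boolean value."""

    def __init__(self, value: bool):
        self.value = value

    def __bool__(self) -> bool:
        return self.value


shared_flag = BoolRef(False)

def report(flag: BoolRef) -> str:
    # Truthiness goes through BoolRef.__bool__, so later mutations are visible here.
    return "accumulating" if flag else "idle"

print(report(shared_flag))  # idle
shared_flag.value = True    # flip the shared flag in place
print(report(shared_flag))  # accumulating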
@@ -88,58 +78,64 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
             handle.remove()
 
 
-class AccumulateJacobian(torch.autograd.Function):
+class BoolRef:
+    """Class wrapping a boolean value, acting as a reference to this boolean value."""
 
-    @staticmethod
-    def forward(
-        output_spec: TreeSpec,
-        vjp: VJP,
-        args: PyTree,
-        gramian_accumulator: GramianAccumulator,
-        module: nn.Module,
-        *flat_grad_outputs: Tensor,
-    ) -> None:
-        # There is no non-batched dimension
-        grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
-        generalized_jacobians = vjp(grad_outputs, args)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
+    def __init__(self, value: bool):
+        self.value = value
 
-    @staticmethod
-    def vmap(
-        _,
-        in_dims: PyTree,
-        output_spec: TreeSpec,
-        vjp: VJP,
-        args: PyTree,
+    def __bool__(self) -> bool:
+        return self.value
+
+
+class Hook:
+    def __init__(
+        self,
+        gramian_accumulation_phase: BoolRef,
+        target_edges: EdgeRegistry,
         gramian_accumulator: GramianAccumulator,
-        module: nn.Module,
-        *flat_jac_outputs: Tensor,
-    ) -> tuple[None, None]:
-        # There is a non-batched dimension
-        jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
-        # We do not vmap over the args for the non-batched dimension
-        in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
-        generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
-        return None, None
+        has_batch_dim: bool,
+    ):
+        self.gramian_accumulation_phase = gramian_accumulation_phase
+        self.target_edges = target_edges
+        self.gramian_accumulator = gramian_accumulator
+        self.has_batch_dim = has_batch_dim
 
-    @staticmethod
-    def _make_path_jacobians(
-        module: nn.Module,
-        generalized_jacobians: dict[str, Tensor],
-    ) -> dict[Tensor, Tensor]:
-        path_jacobians: dict[Tensor, Tensor] = {}
-        for param_name, generalized_jacobian in generalized_jacobians.items():
-            key = module.get_parameter(param_name)
-            jacobian = generalized_jacobian.reshape([-1] + list(key.shape))
-            path_jacobians[key] = jacobian
-        return path_jacobians
+    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
+        if self.gramian_accumulation_phase:
+            return output
 
-    @staticmethod
-    def setup_context(*_):
-        pass
+        flat_outputs, output_spec = tree_flatten(output)
+
+        if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
+            # This can happen only if a module has a trainable param but outputs no tensor that
+            # require grad
+            return output
+
+        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
+        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
+
+        # We only care about running the JacobianAccumulator node, so we need one of its child
+        # edges (the edges of the original outputs of the model) as target. For memory
+        # efficiency, we select the smallest one (that requires grad).
+        inf = float("inf")
+        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
+        index = cast(int, preference.argmin().item())
+        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
+
+        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
+
+        autograd_fn_outputs = JacobianAccumulator.apply(
+            self.gramian_accumulation_phase,
+            output_spec,
+            vjp,
+            args,
+            self.gramian_accumulator,
+            module,
+            *flat_outputs,
+        )
+
+        return tree_unflatten(autograd_fn_outputs, output_spec)
 
 
 class JacobianAccumulator(torch.autograd.Function):
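The Hook.__call__ added above picks, among the module outputs that require grad, the one with the fewest elements and registers its gradient edge as the backward target (see the comment in the hunk). A standalone illustration of that selection with made-up tensors; only the preference/argmin logic mirrors the diff:

# Illustration of the edge-selection heuristic from Hook.__call__; the tensors are made up.
import torch

flat_outputs = [
    torch.randn(8, 16, requires_grad=True),  # 128 elements
    torch.randn(4, requires_grad=True),      # 4 elements: smallest output that requires grad
    torch.randn(2),                          # no grad required, excluded via inf
]
inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
index = int(preference.argmin().item())
print(index)  # 1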
@@ -197,51 +193,55 @@ def backward(ctx, *flat_grad_outputs: Tensor):
         return None, None, None, None, None, None, *flat_grad_outputs
 
 
-class Hook:
-    def __init__(
-        self,
-        gramian_accumulation_phase: BoolRef,
-        target_edges: EdgeRegistry,
-        gramian_accumulator: GramianAccumulator,
-        has_batch_dim: bool,
-    ):
-        self.gramian_accumulation_phase = gramian_accumulation_phase
-        self.target_edges = target_edges
-        self.gramian_accumulator = gramian_accumulator
-        self.has_batch_dim = has_batch_dim
-
-    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
-        if self.gramian_accumulation_phase:
-            return output
-
-        flat_outputs, output_spec = tree_flatten(output)
-
-        if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
-            # This can happen only if a module has a trainable param but outputs no tensor that
-            # require grad
-            return output
-
-        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
-        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
+class AccumulateJacobian(torch.autograd.Function):
 
-        # We only care about running the JacobianAccumulator node, so we need one of its child
-        # edges (the edges of the original outputs of the model) as target. For memory
-        # efficiency, we select the smallest one (that requires grad).
-        inf = float("inf")
-        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
-        index = cast(int, preference.argmin().item())
-        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
+    @staticmethod
+    def forward(
+        output_spec: TreeSpec,
+        vjp: VJP,
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *flat_grad_outputs: Tensor,
+    ) -> None:
+        # There is no non-batched dimension
+        grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
+        generalized_jacobians = vjp(grad_outputs, args)
+        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
+        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
 
-        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
+    @staticmethod
+    def vmap(
+        _,
+        in_dims: PyTree,
+        output_spec: TreeSpec,
+        vjp: VJP,
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *flat_jac_outputs: Tensor,
+    ) -> tuple[None, None]:
+        # There is a non-batched dimension
+        jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
+        # We do not vmap over the args for the non-batched dimension
+        in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
+        generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
+        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
+        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
+        return None, None
 
-        autograd_fn_outputs = JacobianAccumulator.apply(
-            self.gramian_accumulation_phase,
-            output_spec,
-            vjp,
-            args,
-            self.gramian_accumulator,
-            module,
-            *flat_outputs,
-        )
+    @staticmethod
+    def _make_path_jacobians(
+        module: nn.Module,
+        generalized_jacobians: dict[str, Tensor],
+    ) -> dict[Tensor, Tensor]:
+        path_jacobians: dict[Tensor, Tensor] = {}
+        for param_name, generalized_jacobian in generalized_jacobians.items():
+            key = module.get_parameter(param_name)
+            jacobian = generalized_jacobian.reshape([-1] + list(key.shape))
+            path_jacobians[key] = jacobian
+        return path_jacobians
 
-        return tree_unflatten(autograd_fn_outputs, output_spec)
+    @staticmethod
+    def setup_context(*_):
+        pass
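For reference, the reshape in _make_path_jacobians (re-added above) flattens all leading dimensions of a generalized Jacobian into a single row dimension while keeping the parameter's own shape. A small numeric sketch; the shapes are illustrative, not taken from the diff:

# Numeric sketch of the reshape used in AccumulateJacobian._make_path_jacobians.
import torch

param = torch.zeros(2, 3)                       # stand-in for module.get_parameter(name)
generalized_jacobian = torch.randn(4, 5, 2, 3)  # arbitrary leading dims over the param shape
jacobian = generalized_jacobian.reshape([-1] + list(param.shape))
print(jacobian.shape)  # torch.Size([20, 2, 3])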

0 commit comments
