Commit bfce3ec

fix(autogram): Use only rg tensors in autograd functions (#441)
* Work with rg outputs only
* Stop marking WithModuleWithStringOutput as xfail
1 parent 5ee7a11 commit bfce3ec

File tree

3 files changed: 33 additions & 33 deletions

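The fix is easiest to see on a module whose output PyTree mixes tensors with non-tensor values such as strings. The sketch below is illustrative only (the module, shapes, and names are made up, and `torch.utils._pytree` is assumed for the pytree helpers); it just shows what "rg tensors" means here: the flattened outputs that are `Tensor`s with `requires_grad=True`, the only values the custom autograd functions should receive.

```python
import torch
from torch import Tensor, nn
from torch.utils._pytree import tree_flatten  # assumed pytree helpers

class WithStringOutput(nn.Module):
    """Hypothetical module returning a tensor together with a string."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: Tensor) -> tuple[Tensor, str]:
        return self.linear(x), "some metadata"

module = WithStringOutput()
outputs = module(torch.randn(3, 4))
flat_outputs, _ = tree_flatten(outputs)

# "rg" outputs: tensors that require grad. The string (and any frozen tensor)
# is left out, so it never reaches the custom autograd.Function.
rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
print(len(flat_outputs), len(rg_outputs))  # 2 1
```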

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 23 additions & 14 deletions
@@ -101,40 +101,49 @@ def __init__(
         self.gramian_accumulator = gramian_accumulator
         self.has_batch_dim = has_batch_dim

-    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
+    def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
         if self.gramian_accumulation_phase:
-            return output
+            return outputs

-        flat_outputs, output_spec = tree_flatten(output)
+        flat_outputs, output_spec = tree_flatten(outputs)

-        if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
+        rg_outputs = list[Tensor]()
+        rg_output_indices = list[int]()
+        for idx, output in enumerate(flat_outputs):
+            if isinstance(output, Tensor) and output.requires_grad:
+                rg_outputs.append(output)
+                rg_output_indices.append(idx)
+
+        if len(rg_outputs) == 0:
             # This can happen only if a module has a trainable param but outputs no tensor that
             # require grad
-            return output
+            return outputs

         requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
         self.gramian_accumulator.track_parameter_paths(requires_grad_params)

         # We only care about running the JacobianAccumulator node, so we need one of its child
         # edges (the edges of the original outputs of the model) as target. For memory
         # efficiency, we select the smallest one (that requires grad).
-        inf = float("inf")
-        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
+        preference = torch.tensor([t.numel() for t in rg_outputs])
         index = cast(int, preference.argmin().item())
-        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
+        self.target_edges.register(get_gradient_edge(rg_outputs[index]))

-        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
+        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs)

-        autograd_fn_outputs = JacobianAccumulator.apply(
+        autograd_fn_rg_outputs = JacobianAccumulator.apply(
             self.gramian_accumulation_phase,
             vjp,
             args,
             self.gramian_accumulator,
             module,
-            *flat_outputs,
+            *rg_outputs,
         )

-        return tree_unflatten(autograd_fn_outputs, output_spec)
+        for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs):
+            flat_outputs[idx] = output
+
+        return tree_unflatten(flat_outputs, output_spec)


 class JacobianAccumulator(torch.autograd.Function):

@@ -155,9 +164,9 @@ def forward(
         args: PyTree,
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
-        *xs: Tensor,
+        *rg_tensors: Tensor,
     ) -> tuple[Tensor, ...]:
-        return tuple(x.detach() for x in xs)
+        return tuple(t.detach() for t in rg_tensors)

     # For Python version > 3.10, the type of `inputs` should become
     # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
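The new hook body filters the flattened outputs, feeds only the requires-grad tensors through `JacobianAccumulator.apply`, and scatters the results back into their original slots before unflattening. Below is a condensed, self-contained sketch of that filter-and-scatter pattern; `Passthrough` and `hook` are illustrative stand-ins, not torchjd code, and `torch.utils._pytree` is assumed for the pytree helpers.

```python
import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten, tree_unflatten

class Passthrough(torch.autograd.Function):
    """Stand-in for JacobianAccumulator: it only ever sees requires-grad tensors."""

    @staticmethod
    def forward(ctx, *rg_tensors: Tensor) -> tuple[Tensor, ...]:
        return tuple(t.detach() for t in rg_tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tensor) -> tuple[Tensor, ...]:
        return grad_outputs  # identity backward, one gradient per forward input

def hook(outputs):
    flat_outputs, output_spec = tree_flatten(outputs)

    # Select the requires-grad tensors and remember their positions.
    rg_outputs, rg_output_indices = [], []
    for idx, output in enumerate(flat_outputs):
        if isinstance(output, Tensor) and output.requires_grad:
            rg_outputs.append(output)
            rg_output_indices.append(idx)

    if not rg_outputs:
        return outputs  # nothing differentiable to intercept

    new_rg_outputs = Passthrough.apply(*rg_outputs)

    # Scatter the processed tensors back into their original slots, leaving
    # strings and non-differentiable tensors untouched.
    for idx, output in zip(rg_output_indices, new_rg_outputs):
        flat_outputs[idx] = output

    return tree_unflatten(flat_outputs, output_spec)

x = torch.randn(2, 3, requires_grad=True)
result = hook({"logits": x * 2, "label": "cat"})
print(result["label"], result["logits"].requires_grad)  # cat True
```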

src/torchjd/autogram/_vjp.py

Lines changed: 9 additions & 18 deletions
@@ -70,16 +70,18 @@ def _call_on_one_instance(
         args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
         grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]

-        def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
+        def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
             all_state = {
                 **trainable_params,
                 **dict(self.module.named_buffers()),
                 **self.frozen_params,
             }
             output = torch.func.functional_call(self.module, all_state, args_j)
-            return tree_flatten(output)[0]
+            flat_outputs = tree_flatten(output)[0]
+            rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
+            return rg_outputs

-        vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1]
+        vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]

         # vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
         # functional has a single primal which is dict(module.named_parameters()). We therefore take

@@ -95,28 +97,17 @@ class AutogradVJP(ModuleVJP):
     forward pass.
     """

-    def __init__(self, module: nn.Module, outputs: Sequence[Tensor]):
+    def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
         super().__init__(module)

-        self.outputs_that_require_grad = list[Tensor]()
-        self.mask = list[bool]()
-        for output in outputs:
-            requires_grad = output.requires_grad
-            if requires_grad:
-                self.outputs_that_require_grad.append(output)
-            self.mask.append(requires_grad)
-
+        self.rg_outputs = rg_outputs
         self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)

     def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]:
-
-        # Only keep the grad_outputs corresponding to outputs that require grad.
-        grad_outputs_ = [grad_output for grad_output, rg in zip(grad_outputs, self.mask) if rg]
-
         grads = torch.autograd.grad(
-            self.outputs_that_require_grad,
+            self.rg_outputs,
             self.flat_trainable_params,
-            grad_outputs_,
+            grad_outputs,
             retain_graph=True,
             allow_unused=True,
             materialize_grads=True,
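The functional passed to `torch.func.vjp` now also returns only the requires-grad outputs, so the cotangents given to the VJP line up with them one-to-one and the masking logic in `AutogradVJP` disappears. A minimal sketch of that `functional_call` + `vjp` pattern over a module's parameters (the `nn.Linear` module, input shape, and all-ones cotangents are illustrative assumptions):

```python
import torch
from torch import Tensor, nn
from torch.utils._pytree import tree_flatten

module = nn.Linear(4, 2)
trainable_params = dict(module.named_parameters())
x = torch.randn(1, 4)

def functional_model_call(params: dict[str, Tensor]) -> list[Tensor]:
    # Run the module with `params` substituted for its registered parameters.
    output = torch.func.functional_call(module, params, (x,))
    flat_outputs = tree_flatten(output)[0]
    # Keep only the tensors that require grad, matching the cotangents below.
    return [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]

# vjp returns (outputs, vjp_func); vjp_func maps cotangents on the outputs to
# gradients with respect to the single primal (the parameter dict).
outputs, vjp_func = torch.func.vjp(functional_model_call, trainable_params)
cotangents = [torch.ones_like(t) for t in outputs]
(param_grads,) = vjp_func(cotangents)
print({name: tuple(g.shape) for name, g in param_grads.items()})
# e.g. {'weight': (2, 4), 'bias': (2,)}
```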

tests/unit/autogram/test_engine.py

Lines changed: 1 addition & 1 deletion
@@ -109,6 +109,7 @@
         (Ndim3Output, 32),
         (Ndim4Output, 32),
         (WithDropout, 32),
+        (WithModuleWithStringOutput, 32),
         (FreeParam, 32),
         (NoFreeParam, 32),
         param(Cifar10Model, 16, marks=mark.slow),

@@ -167,7 +168,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
         WithModuleTrackingRunningStats,
         param(WithRNN, marks=mark.xfail_if_cuda),
         WithModuleWithStringArg,
-        param(WithModuleWithStringOutput, marks=mark.xfail),
     ],
 )
 @mark.parametrize("batch_size", [1, 3, 32])
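The test change moves `WithModuleWithStringOutput` into the regular parametrization and drops its `xfail` mark, i.e. the case is now required to pass rather than expected to fail. As a generic reminder of the mechanism (plain pytest, unrelated to the torchjd suite):

```python
from pytest import mark, param

@mark.parametrize(
    "value",
    [
        1,                            # regular case: must pass
        param(-1, marks=mark.xfail),  # expected to fail; reported as XFAIL, not an error
    ],
)
def test_is_positive(value: int) -> None:
    assert value > 0
```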
