Skip to content

Commit 63d9d84

Browse files
refactor(autogram): Switch to ComputeModuleJacobians (#452)
* Make `AccumulateJacobian` return the Jacobian w.r.t. the parameters of the module rather than providing them to the `GramianAccumulator` directly.
* Rename the class `AccumulateJacobian` to `ComputeModuleJacobians` to better reflect its role.

Co-authored-by: Valérian Rey <[email protected]>
1 parent 19d375d commit 63d9d84

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 11 additions & 14 deletions

Original file line number | Diff line number | Diff line change
@@ -206,52 +206,49 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple:
206206
if not ctx.gramian_accumulation_phase:
207207
return None, None, None, None, None, None, *grad_outputs
208208

209-
AccumulateJacobian.apply(
209+
path_jacobians = ComputeModuleJacobians.apply(
210210
ctx.vjp,
211211
ctx.args,
212212
ctx.kwargs,
213-
ctx.gramian_accumulator,
214213
ctx.module,
215214
*grad_outputs,
216215
)
216+
ctx.gramian_accumulator.accumulate_path_jacobians(path_jacobians)
217217

218218
return None, None, None, None, None, None, *grad_outputs
219219

220220

221-
class AccumulateJacobian(torch.autograd.Function):
221+
class ComputeModuleJacobians(torch.autograd.Function):
222222

223223
@staticmethod
224224
def forward(
225225
vjp: VJP,
226226
args: tuple[PyTree, ...],
227227
kwargs: dict[str, PyTree],
228-
gramian_accumulator: GramianAccumulator,
229228
module: nn.Module,
230229
*grad_outputs: Tensor,
231-
) -> None:
230+
) -> dict[Tensor, Tensor]:
232231
# There is no non-batched dimension
233232
generalized_jacobians = vjp(grad_outputs, args, kwargs)
234-
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
235-
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
233+
path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians)
234+
return path_jacobians
236235

237236
@staticmethod
238237
def vmap(
239238
_,
240-
in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]]
239+
in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]]
241240
vjp: VJP,
242241
args: tuple[PyTree, ...],
243242
kwargs: dict[str, PyTree],
244-
gramian_accumulator: GramianAccumulator,
245243
module: nn.Module,
246244
*jac_outputs: Tensor,
247-
) -> tuple[None, None]:
245+
) -> tuple[dict[Tensor, Tensor], None]:
248246
# There is a non-batched dimension
249247
# We do not vmap over the args for the non-batched dimension
250-
in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
248+
in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
251249
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs)
252-
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
253-
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
254-
return None, None
250+
path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians)
251+
return path_jacobians, None
255252

256253
@staticmethod
257254
def _make_path_jacobians(

0 commit comments

Comments (0)