@@ -206,52 +206,49 @@ def backward(ctx, *grad_outputs: Tensor) -> tuple:
         if not ctx.gramian_accumulation_phase:
             return None, None, None, None, None, None, *grad_outputs

-        AccumulateJacobian.apply(
+        path_jacobians = ComputeModuleJacobians.apply(
             ctx.vjp,
             ctx.args,
             ctx.kwargs,
-            ctx.gramian_accumulator,
             ctx.module,
             *grad_outputs,
         )
+        ctx.gramian_accumulator.accumulate_path_jacobians(path_jacobians)

         return None, None, None, None, None, None, *grad_outputs


-class AccumulateJacobian(torch.autograd.Function):
+class ComputeModuleJacobians(torch.autograd.Function):

     @staticmethod
     def forward(
         vjp: VJP,
         args: tuple[PyTree, ...],
         kwargs: dict[str, PyTree],
-        gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *grad_outputs: Tensor,
-    ) -> None:
+    ) -> dict[Tensor, Tensor]:
         # There is no non-batched dimension
         generalized_jacobians = vjp(grad_outputs, args, kwargs)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
+        path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians)
+        return path_jacobians

     @staticmethod
     def vmap(
         _,
-        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]]
+        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, *tuple[int | None, ...]]
         vjp: VJP,
         args: tuple[PyTree, ...],
         kwargs: dict[str, PyTree],
-        gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *jac_outputs: Tensor,
-    ) -> tuple[None, None]:
+    ) -> tuple[dict[Tensor, Tensor], None]:
         # There is a non-batched dimension
         # We do not vmap over the args for the non-batched dimension
-        in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
+        in_dims = (in_dims[4:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
         generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
-        return None, None
+        path_jacobians = ComputeModuleJacobians._make_path_jacobians(module, generalized_jacobians)
+        return path_jacobians, None

     @staticmethod
     def _make_path_jacobians(
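The refactor above leans on `torch.autograd.Function` support for a custom `vmap` staticmethod (available since PyTorch 2.0): that is what lets `ComputeModuleJacobians.apply` run under `torch.vmap` when a non-batched dimension is present, while the accumulation into the Gramian now happens in the caller instead of inside the Function. The sketch below is a minimal, self-contained illustration of that staticmethod pattern, not code from this repository; `ScalePerSample` and its arguments are hypothetical names used only for illustration.

```python
import torch
from torch import Tensor


class ScalePerSample(torch.autograd.Function):
    # Hypothetical toy Function illustrating the custom `vmap`
    # staticmethod that ComputeModuleJacobians also defines.

    @staticmethod
    def forward(x: Tensor, scale: float) -> Tensor:
        # Plain path: no extra vmapped dimension.
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        _, scale = inputs
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output * ctx.scale, None

    @staticmethod
    def vmap(info, in_dims, x: Tensor, scale: float):
        # Called when `torch.vmap` wraps `ScalePerSample.apply`. `in_dims`
        # mirrors the positional arguments and says which of them carry the
        # batched dimension. Elementwise scaling keeps that dimension where
        # it was, so we return it unchanged as the output dims.
        x_dim, _ = in_dims
        return x * scale, x_dim


xs = torch.randn(4, 3)
out = torch.vmap(lambda x: ScalePerSample.apply(x, 2.0))(xs)
print(out.shape)  # torch.Size([4, 3])
```

The same division of labor appears in the diff: `forward` handles the case without a non-batched dimension, while `vmap` receives `in_dims` describing the extra dimension of the `jac_outputs` and returns the computed path Jacobians together with their output dims, leaving accumulation to `backward` in the caller.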