@@ -101,7 +101,7 @@ def __init__(
         self.gramian_accumulator = gramian_accumulator
         self.has_batch_dim = has_batch_dim

-    def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
+    def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree:
         if self.gramian_accumulation_phase:
             return outputs

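Note on the hunk above: the tighter `tuple[PyTree, ...]` annotation matches PyTorch's forward-hook convention, in which the module's positional arguments always arrive packed in a tuple. A minimal sketch of that convention (the `debug_hook` and `layer` names are illustrative only, not part of this change):

```python
import torch
from torch import nn

def debug_hook(module: nn.Module, args: tuple, output):
    # Forward hooks always receive the positional args as a tuple,
    # even when the module was called with a single input.
    print(type(module).__name__, len(args), tuple(output.shape))

layer = nn.Linear(3, 2)
handle = layer.register_forward_hook(debug_hook)
_ = layer(torch.randn(4, 3))  # prints: Linear 1 (4, 2)
handle.remove()
```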
@@ -129,7 +129,14 @@ def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
         index = cast(int, preference.argmin().item())
         self.target_edges.register(get_gradient_edge(rg_outputs[index]))

-        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs)
+        vjp: VJP
+        if self.has_batch_dim:
+            rg_outputs_in_dims = (0,) * len(rg_outputs)
+            args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
+            in_dims = (rg_outputs_in_dims, args_in_dims)
+            vjp = FunctionalVJP(module, in_dims)
+        else:
+            vjp = AutogradVJP(module, rg_outputs)

         autograd_fn_rg_outputs = JacobianAccumulator.apply(
             self.gramian_accumulation_phase,
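The `in_dims` structure built in this hunk follows the `torch.vmap` convention: `0` for every leaf that carries a leading batch dimension, `None` for leaves that should be passed through unbatched (here, the non-tensor leaves of `args`). A self-contained sketch of that convention (the `scaled_sum` function and its arguments are hypothetical, not taken from this codebase):

```python
import torch
from torch import Tensor

def scaled_sum(x: Tensor, scale: float) -> Tensor:
    return (x * scale).sum()

batched_x = torch.randn(5, 3)  # leading batch dimension of size 5
# One in_dim per argument: map over dim 0 of `x`, broadcast the non-tensor `scale`.
per_sample = torch.vmap(scaled_sum, in_dims=(0, None))(batched_x, 2.0)
print(per_sample.shape)  # torch.Size([5])
```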
@@ -161,15 +168,15 @@ class JacobianAccumulator(torch.autograd.Function):
     def forward(
         gramian_accumulation_phase: BoolRef,
         vjp: VJP,
-        args: PyTree,
+        args: tuple[PyTree, ...],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *rg_tensors: Tensor,
     ) -> tuple[Tensor, ...]:
         return tuple(t.detach() for t in rg_tensors)

     # For Python version > 3.10, the type of `inputs` should become
-    # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
+    # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
     @staticmethod
     def setup_context(
         ctx,
@@ -183,7 +190,9 @@ def setup_context(
         ctx.module = inputs[4]

     @staticmethod
-    def backward(ctx, *grad_outputs: Tensor):
+    def backward(ctx, *grad_outputs: Tensor) -> tuple:
+        # Return type for Python > 3.10: tuple[None, None, None, None, None, *tuple[Tensor, ...]]
+
         if not ctx.gramian_accumulation_phase:
             return None, None, None, None, None, *grad_outputs

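For context on the five leading `None`s above: a custom `autograd.Function` must return exactly one gradient slot per argument of `forward`, with `None` in the slots of non-tensor (or non-differentiable) arguments. A minimal standalone example of that rule, using the same `forward`/`setup_context` split as this class (`ScaleBy` is a made-up example, not part of the change):

```python
import torch
from torch import Tensor

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(x: Tensor, factor: float) -> Tensor:
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        _, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> tuple:
        # One slot per forward argument: a gradient for `x`, None for `factor`.
        return grad_output * ctx.factor, None

x = torch.randn(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```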
@@ -203,7 +212,7 @@ class AccumulateJacobian(torch.autograd.Function):
     @staticmethod
     def forward(
         vjp: VJP,
-        args: PyTree,
+        args: tuple[PyTree, ...],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *grad_outputs: Tensor,
@@ -216,9 +225,9 @@ def forward(
     @staticmethod
     def vmap(
         _,
-        in_dims: PyTree,
+        in_dims: tuple,  # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]]
         vjp: VJP,
-        args: PyTree,
+        args: tuple[PyTree, ...],
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
         *jac_outputs: Tensor,
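The `vmap` staticmethod annotated above is PyTorch's hook for giving a custom `autograd.Function` its own batching rule: it receives an `info` object, one `in_dim` per `forward` argument (`None` for unbatched ones), and must return the outputs together with their `out_dims`. A minimal sketch of that protocol on a toy function (`Double` is illustrative only and unrelated to `AccumulateJacobian`):

```python
import torch
from torch import Tensor

class Double(torch.autograd.Function):
    @staticmethod
    def forward(x: Tensor) -> Tensor:
        return 2 * x

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        pass

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        return 2 * grad_output

    @staticmethod
    def vmap(info, in_dims, x):
        # `in_dims` has one entry per forward argument; this element-wise op
        # keeps the batch dimension where it was, so out_dims == in_dims[0].
        return 2 * x, in_dims[0]

out = torch.vmap(Double.apply)(torch.randn(4, 3))
print(out.shape)  # torch.Size([4, 3])
```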
@@ -244,5 +253,5 @@ def _make_path_jacobians(
         return path_jacobians

     @staticmethod
-    def setup_context(*_):
+    def setup_context(*_) -> None:
         pass