@@ -101,40 +101,49 @@ def __init__(
         self.gramian_accumulator = gramian_accumulator
         self.has_batch_dim = has_batch_dim
 
-    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
+    def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
         if self.gramian_accumulation_phase:
-            return output
+            return outputs
 
-        flat_outputs, output_spec = tree_flatten(output)
+        flat_outputs, output_spec = tree_flatten(outputs)
 
-        if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
+        rg_outputs = list[Tensor]()
+        rg_output_indices = list[int]()
+        for idx, output in enumerate(flat_outputs):
+            if isinstance(output, Tensor) and output.requires_grad:
+                rg_outputs.append(output)
+                rg_output_indices.append(idx)
+
+        if len(rg_outputs) == 0:
             # This can happen only if a module has a trainable param but outputs no tensor that
             # require grad
-            return output
+            return outputs
 
         requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
         self.gramian_accumulator.track_parameter_paths(requires_grad_params)
 
         # We only care about running the JacobianAccumulator node, so we need one of its child
         # edges (the edges of the original outputs of the model) as target. For memory
         # efficiency, we select the smallest one (that requires grad).
-        inf = float("inf")
-        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
+        preference = torch.tensor([t.numel() for t in rg_outputs])
         index = cast(int, preference.argmin().item())
-        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
+        self.target_edges.register(get_gradient_edge(rg_outputs[index]))
 
-        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
+        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs)
 
-        autograd_fn_outputs = JacobianAccumulator.apply(
+        autograd_fn_rg_outputs = JacobianAccumulator.apply(
             self.gramian_accumulation_phase,
             vjp,
             args,
             self.gramian_accumulator,
             module,
-            *flat_outputs,
+            *rg_outputs,
         )
 
-        return tree_unflatten(autograd_fn_outputs, output_spec)
+        for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs):
+            flat_outputs[idx] = output
+
+        return tree_unflatten(flat_outputs, output_spec)
 
 
 class JacobianAccumulator(torch.autograd.Function):
@@ -155,9 +164,9 @@ def forward(
         args: PyTree,
         gramian_accumulator: GramianAccumulator,
         module: nn.Module,
-        *xs: Tensor,
+        *rg_tensors: Tensor,
     ) -> tuple[Tensor, ...]:
-        return tuple(x.detach() for x in xs)
+        return tuple(t.detach() for t in rg_tensors)
 
     # For Python version > 3.10, the type of `inputs` should become
     # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
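Below is a minimal, hypothetical sketch of the pattern this commit introduces: flatten the module's output pytree, route only the leaves that are tensors requiring grad through an autograd.Function, and splice the results back into their original slots before unflattening. The names PassThrough and reroute_rg_tensors are illustrative stand-ins (not the library's classes), and tree_flatten/tree_unflatten are assumed to come from torch.utils._pytree.

import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten, tree_unflatten


class PassThrough(torch.autograd.Function):
    # Stand-in for JacobianAccumulator: forward just detaches its inputs,
    # backward passes the incoming gradients through unchanged.
    @staticmethod
    def forward(ctx, *rg_tensors: Tensor) -> tuple[Tensor, ...]:
        return tuple(t.detach() for t in rg_tensors)

    @staticmethod
    def backward(ctx, *grad_outputs: Tensor) -> tuple[Tensor, ...]:
        return grad_outputs


def reroute_rg_tensors(outputs):
    flat_outputs, output_spec = tree_flatten(outputs)

    # Keep only the leaves that are tensors requiring grad, remembering their positions.
    rg_output_indices = [
        i for i, t in enumerate(flat_outputs) if isinstance(t, Tensor) and t.requires_grad
    ]
    rg_outputs = [flat_outputs[i] for i in rg_output_indices]

    if not rg_outputs:
        return outputs

    new_rg_outputs = PassThrough.apply(*rg_outputs)

    # Splice the rerouted tensors back into their original slots, then unflatten.
    for i, t in zip(rg_output_indices, new_rg_outputs):
        flat_outputs[i] = t
    return tree_unflatten(flat_outputs, output_spec)


if __name__ == "__main__":
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(2)  # does not require grad, left untouched
    out = reroute_rg_tensors({"a": a * 2, "b": b, "label": "meta"})
    print(out["a"].grad_fn)  # the PassThrough backward node
    print(out["b"] is b)     # True: non-rg leaves are returned as-is

Non-tensor leaves and tensors that do not require grad never enter the Function, which is what lets the hook above register a gradient edge on the smallest requires-grad output only.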