@@ -141,47 +141,69 @@ def linear_process_input(x: Tensor, layer: Linear, kfac_approx: str) -> Tensor:
141141
142142
def process_grad_output(
    grad_output: Tensor,
    module: Module,
    loss_average: Union[None, str],
    kfac_approx: str,
) -> Tensor:
    """Reshape output gradients into matrices and apply scaling.

    Args:
        grad_output: The gradient w.r.t. the output of the module.
        module: The module.
        loss_average: Whether the loss function is a mean over per-sample
            losses and if yes, over which dimensions the mean is taken.
            If `"batch"`, the loss function is a mean over as many terms as
            the size of the mini-batch. If `"batch+sequence"`, the loss
            function is a mean over as many terms as the size of the
            mini-batch times the sequence length, e.g. in the case of
            language modeling. If `None`, the loss function is a sum. This
            argument is used to ensure that the preconditioner is scaled
            consistently with the loss and the gradient. Default: `"batch"`.
        kfac_approx: The KFAC approximation to use for linear weight-sharing
            layers. Possible values are `"expand"` and `"reduce"`.

    Returns:
        The processed output gradient.

    Raises:
        AssertionError: If `loss_average` is not `None`, `"batch"`, or
            `"batch+sequence"`.
        AssertionError: If `kfac_approx` is neither `"expand"` nor `"reduce"`.
        NotImplementedError: If the module is not supported.
    """
    assert loss_average in {None, "batch", "batch+sequence"}
    assert kfac_approx in {"expand", "reduce"}

    # No extra gradient scaling is applied at this level; the layer-specific
    # processors adjust it further based on `loss_average`.
    grad_scaling = 1.0

    # Dispatch to the layer-specific processor.
    if isinstance(module, Conv2d):
        processor = conv2d_process_grad_output
    elif isinstance(module, Linear):
        processor = linear_process_grad_output
    else:
        raise NotImplementedError(f"Can't process grad_output for {module}.")

    return processor(grad_output, loss_average, grad_scaling, kfac_approx)
176190def conv2d_process_grad_output (
177- g : Tensor , batch_averaged : bool , scaling : float , kfac_approx : str
191+ g : Tensor , loss_average : Union [ None , str ] , scaling : float , kfac_approx : str
178192) -> Tensor :
179193 """Process the output gradient of a convolution before the self-inner product.
180194
181195 Args:
182196 g: Gradient w.r.t. the output of a convolution. Has shape
183197 `[batch_size, C_out, O1, O2]`.
184- batch_averaged: Whether to multiply with the batch size.
198+ loss_average: Whether the loss function is a mean over per-sample
199+ losses and if yes, over which dimensions the mean is taken.
200+ If `"batch"`, the loss function is a mean over as many terms as
201+ the size of the mini-batch. If `"batch+sequence"`, the loss
202+ function is a mean over as many terms as the size of the
203+ mini-batch times the sequence length, e.g. in the case of
204+ language modeling. If `None`, the loss function is a sum. This
205+ argument is used to ensure that the preconditioner is scaled
206+ consistently with the loss and the gradient. Default: `"batch"`.
185207 scaling: An additional scaling that will be applied to the gradient.
186208 kfac_approx: The KFAC approximation to use. Possible values are
187209 `"expand"` and `"reduce"`.
@@ -190,11 +212,14 @@ def conv2d_process_grad_output(
190212 The processed scaled gradient. Has shape `[batch_size, C_out]` for
191213 `"reduce"` and `[batch_size * O1 * O2, C_out]` for `"expand"`.
192214 """
193- # The scaling by `sqrt(batch_size)` when `batch_averaged=True` assumes
194- # that we are in the reduce setting, i.e. the number of loss terms equals
195- # the batch size.
196- batch_size = g .shape [0 ]
197- scaling = scaling * sqrt (batch_size ) if batch_averaged else scaling
215+ # We have to adjust the scaling to account for the mean reduction of the
216+ # loss used for computing the gradients when loss_average is not None.
217+ if loss_average is not None :
218+ num_loss_terms = g .shape [0 ] # batch_size
219+ if loss_average == "batch+sequence" :
220+ num_loss_terms *= g .shape [2 :].numel () # spatial size = O1 * O2
221+
222+ scaling *= sqrt (num_loss_terms )
198223
199224 if kfac_approx == "expand" :
200225 # KFAC-expand approximation
@@ -207,15 +232,23 @@ def conv2d_process_grad_output(
207232
208233
def linear_process_grad_output(
    g: Tensor, loss_average: Union[None, str], scaling: float, kfac_approx: str
) -> Tensor:
    """Process the output gradient of a linear layer before the self-inner product.

    Args:
        g: Gradient w.r.t. the output of a linear layer. Has shape
            `[batch_size, ..., d_out]` where `...` is an arbitrary number of
            weight-shared dimensions.
        loss_average: Whether the loss function is a mean over per-sample
            losses and if yes, over which dimensions the mean is taken.
            If `"batch"`, the loss function is a mean over as many terms as
            the size of the mini-batch. If `"batch+sequence"`, the loss
            function is a mean over as many terms as the size of the
            mini-batch times the sequence length, e.g. in the case of
            language modeling. If `None`, the loss function is a sum. This
            argument is used to ensure that the preconditioner is scaled
            consistently with the loss and the gradient. Default: `"batch"`.
        scaling: An additional scaling that will be applied to the gradient.
        kfac_approx: The KFAC approximation to use for linear weight-sharing
            layers. Possible values are `"expand"` and `"reduce"`.

    Returns:
        The processed gradient. Has shape `[batch_size, d_out]` for `"reduce"`
        and `[batch_size * ..., d_out]` for `"expand"`.
    """
    # Undo the mean reduction of the loss so the preconditioner is scaled
    # consistently with the loss and the gradient. Shapes must be read off
    # BEFORE any reshaping below.
    if loss_average is not None:
        batch_size = g.shape[0]
        # Product of all weight-shared dimensions (1 if there are none or
        # the mean is taken over the batch only).
        shared_terms = (
            g.shape[1:-1].numel() if loss_average == "batch+sequence" else 1
        )
        scaling = scaling * sqrt(batch_size * shared_terms)

    if kfac_approx == "expand":
        # KFAC-expand: treat every weight-shared location as its own sample.
        flattened = rearrange(g, "b ... d_out -> (b ...) d_out")
        return flattened * scaling

    # KFAC-reduce: sum the gradient over all weight-shared dimensions.
    summed = reduce(g, "b ... d_out -> b d_out", "sum")
    return summed * scaling
0 commit comments