Skip to content

Commit 8399a6c

Browse files
authored
[REF] Refactor batch_averaged argument (#44)
* Refactor batch_averaged argument
* Fix typo
* Remove superfluous slicing
* Refactor scaling for batch_averaged
* Add SUPPORTED_LOSS_AVERAGE class attribute
* Rename batch_averaged to loss_average
* Fix docstrings
1 parent 16640b6 commit 8399a6c

13 files changed

+128
-57
lines changed

docs/examples/example_03_param_groups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
"momentum": 0.9,
6161
"weight_decay": 1e-2,
6262
"lr_cov": 1e-2,
63-
"batch_averaged": True,
63+
"loss_average": "batch",
6464
"T": 1,
6565
"alpha1": 0.5,
6666
}

docs/examples/example_04_advanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
"momentum": 0.9,
9494
"weight_decay": 1e-2,
9595
"lr_cov": 1e-2,
96-
"batch_averaged": True,
96+
"loss_average": "batch",
9797
"T": 1,
9898
"alpha1": 0.5,
9999
"structures": ("dense", "dense"),

singd/optim/optimizer.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class SINGD(Optimizer):
4949
https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.state_dict.html) and
5050
[`.load_state_dict()`](\
5151
https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.load_state_dict.html)).
52+
SUPPORTED_LOSS_AVERAGE: Supported loss averaging schemes.
5253
_step_supports_amp_scaling: Indicates that `step` handles gradient scaling
5354
internally if the optimizer is used together with a
5455
[`torch.cuda.amp.GradScaler`](\
@@ -69,6 +70,11 @@ class SINGD(Optimizer):
6970
"triutoeplitz": TriuToeplitzMatrix,
7071
}
7172
SUPPORTED_MODULES: Tuple[Type[Module], ...] = (Linear, Conv2d)
73+
SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
74+
None,
75+
"batch",
76+
"batch+sequence",
77+
)
7278
_step_supports_amp_scaling = True # do not modify this name (PyTorch convention)!
7379

7480
def __init__(
@@ -81,7 +87,7 @@ def __init__(
8187
alpha1: float = 0.5, # α₁ in the paper
8288
weight_decay: float = 0.0, # γ in the paper
8389
T: int = 10, # T in the paper
84-
batch_averaged: bool = True,
90+
loss_average: Union[None, str] = "batch",
8591
lr_cov: Union[float, Callable[[int], float]] = 1e-2, # β₁ in the paper
8692
structures: Tuple[str, str] = ("dense", "dense"),
8793
kfac_approx: str = "expand",
@@ -121,8 +127,15 @@ def __init__(
121127
weight_decay: (\\(\\gamma\\) in the paper) Weight decay on the parameters.
122128
Default: `0.0`.
123129
T: Pre-conditioner update frequency. Default: `10`.
124-
batch_averaged: Whether the loss function is a mean over per-sample
125-
losses. Default is `True`. If `False `, the loss function is a sum.
130+
loss_average: Whether the loss function is a mean over per-sample
131+
losses and if yes, over which dimensions the mean is taken.
132+
If `"batch"`, the loss function is a mean over as many terms as
133+
the size of the mini-batch. If `"batch+sequence"`, the loss
134+
function is a mean over as many terms as the size of the
135+
mini-batch times the sequence length, e.g. in the case of
136+
language modeling. If `None`, the loss function is a sum. This
137+
argument is used to ensure that the preconditioner is scaled
138+
consistently with the loss and the gradient. Default: `"batch"`.
126139
lr_cov: (β₁ in the paper) Learning rate for the updates of the pre-
127140
conditioner momenta \\(\\mathbf{m}_\\mathbf{K}\\) and
128141
\\(\\mathbf{m}_\\mathbf{C}\\). Default is `1e-2`. Also allows for a
@@ -205,7 +218,7 @@ def __init__(
205218
alpha1=alpha1,
206219
weight_decay=weight_decay,
207220
T=T,
208-
batch_averaged=batch_averaged,
221+
loss_average=loss_average,
209222
lr_cov=lr_cov,
210223
structures=structures,
211224
kfac_approx=kfac_approx,
@@ -280,6 +293,8 @@ def _check_param_groups(self, model: Module) -> Dict[int, int]:
280293
ValueError: If `kfac_approx` for any param group is not
281294
`'expand'` or `'reduce'`.
282295
ValueError: If parameters in a supported layer are in different groups.
296+
ValueError: If `loss_average` for any param group is not in
297+
self.SUPPORTED_LOSS_AVERAGE.
283298
284299
Returns:
285300
A dictionary mapping parameter IDs (`.data_ptr()`) to group indices.
@@ -298,6 +313,12 @@ def _check_param_groups(self, model: Module) -> Dict[int, int]:
298313
"kfac_approx has to be set to either 'expand' or 'reduce', "
299314
f"but was set to {group['kfac_approx']}."
300315
)
316+
if group["loss_average"] not in self.SUPPORTED_LOSS_AVERAGE:
317+
raise ValueError(
318+
"loss_average has to be set to one out of "
319+
f"{self.SUPPORTED_LOSS_AVERAGE}, but was set to "
320+
f"{group['loss_average']}."
321+
)
301322

302323
# Find out which parameter is in which group
303324
param_to_group_idx = {}
@@ -551,7 +572,7 @@ def _accumulate_H_terms(
551572
if self.steps % T != 0:
552573
return
553574

554-
batch_averaged = self._get_param_group_entry(module, "batch_averaged")
575+
loss_average = self._get_param_group_entry(module, "loss_average")
555576
kfac_approx = self._get_param_group_entry(module, "kfac_approx")
556577
module_name = self.module_names[module]
557578

@@ -563,7 +584,7 @@ def _accumulate_H_terms(
563584

564585
g = grad_output[0].data
565586
# Process into matrix according to kfac_approx, add scaling from batch average
566-
g = process_grad_output(g, module, batch_averaged, kfac_approx)
587+
g = process_grad_output(g, module, loss_average, kfac_approx)
567588

568589
# 2) Update H_K, H_C
569590
K, C = self.Ks[module_name], self.Cs[module_name]
@@ -583,7 +604,7 @@ def _accumulate_H_terms(
583604
# If DDP is used.
584605
if dist.is_initialized():
585606
# all-reduce across devices (computes average by default).
586-
op = dist.ReduceOp.AVG if batch_averaged else dist.ReduceOp.SUM
607+
op = dist.ReduceOp.AVG if loss_average else dist.ReduceOp.SUM
587608
H_K.all_reduce(op=op)
588609
H_C.all_reduce(op=op)
589610

@@ -672,8 +693,8 @@ def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]:
672693
# If DDP is used.
673694
if dist.is_initialized():
674695
# all-reduce across devices.
675-
batch_averaged = self._get_param_group_entry(module, "batch_averaged")
676-
op = dist.ReduceOp.AVG if batch_averaged else dist.ReduceOp.SUM
696+
loss_average = self._get_param_group_entry(module, "loss_average")
697+
op = dist.ReduceOp.AVG if loss_average else dist.ReduceOp.SUM
677698
dist.all_reduce(nat_grad, op=op)
678699

679700
# 3) UN-CONCATENATE, UN-RESHAPE, AND COPY THE NATURAL GRADIENT TO `.GRAD`

singd/optim/utils.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,47 +141,69 @@ def linear_process_input(x: Tensor, layer: Linear, kfac_approx: str) -> Tensor:
141141

142142

143143
def process_grad_output(
144-
grad_output: Tensor, module: Module, batch_averaged: bool, kfac_approx: str
144+
grad_output: Tensor,
145+
module: Module,
146+
loss_average: Union[None, str],
147+
kfac_approx: str,
145148
) -> Tensor:
146149
"""Reshape output gradients into matrices and apply scaling.
147150
148151
Args:
149152
grad_output: The gradient w.r.t. the output of the module.
150153
module: The module.
151-
batch_averaged: Whether the loss is a mean over per-sample losses.
154+
loss_average: Whether the loss function is a mean over per-sample
155+
losses and if yes, over which dimensions the mean is taken.
156+
If `"batch"`, the loss function is a mean over as many terms as
157+
the size of the mini-batch. If `"batch+sequence"`, the loss
158+
function is a mean over as many terms as the size of the
159+
mini-batch times the sequence length, e.g. in the case of
160+
language modeling. If `None`, the loss function is a sum. This
161+
argument is used to ensure that the preconditioner is scaled
162+
consistently with the loss and the gradient. Default: `"batch"`.
152163
kfac_approx: The KFAC approximation to use for linear weight-sharing
153164
layers. Possible values are `"expand"` and `"reduce"`.
154165
155166
Returns:
156167
The processed output gradient.
157168
158169
Raises:
170+
AssertionError: If `loss_average` is not `None`, `"batch"`, or
171+
`"batch+sequence"`.
159172
AssertionError: If `kfac_approx` is neither `"expand"` nor `"reduce"`.
160173
NotImplementedError: If the module is not supported.
161174
"""
175+
assert loss_average in {None, "batch", "batch+sequence"}
162176
assert kfac_approx in {"expand", "reduce"}
163177
grad_scaling = 1.0
164178
if isinstance(module, Conv2d):
165179
return conv2d_process_grad_output(
166-
grad_output, batch_averaged, grad_scaling, kfac_approx
180+
grad_output, loss_average, grad_scaling, kfac_approx
167181
)
168182
elif isinstance(module, Linear):
169183
return linear_process_grad_output(
170-
grad_output, batch_averaged, grad_scaling, kfac_approx
184+
grad_output, loss_average, grad_scaling, kfac_approx
171185
)
172186
else:
173187
raise NotImplementedError(f"Can't process grad_output for {module}.")
174188

175189

176190
def conv2d_process_grad_output(
177-
g: Tensor, batch_averaged: bool, scaling: float, kfac_approx: str
191+
g: Tensor, loss_average: Union[None, str], scaling: float, kfac_approx: str
178192
) -> Tensor:
179193
"""Process the output gradient of a convolution before the self-inner product.
180194
181195
Args:
182196
g: Gradient w.r.t. the output of a convolution. Has shape
183197
`[batch_size, C_out, O1, O2]`.
184-
batch_averaged: Whether to multiply with the batch size.
198+
loss_average: Whether the loss function is a mean over per-sample
199+
losses and if yes, over which dimensions the mean is taken.
200+
If `"batch"`, the loss function is a mean over as many terms as
201+
the size of the mini-batch. If `"batch+sequence"`, the loss
202+
function is a mean over as many terms as the size of the
203+
mini-batch times the sequence length, e.g. in the case of
204+
language modeling. If `None`, the loss function is a sum. This
205+
argument is used to ensure that the preconditioner is scaled
206+
consistently with the loss and the gradient. Default: `"batch"`.
185207
scaling: An additional scaling that will be applied to the gradient.
186208
kfac_approx: The KFAC approximation to use. Possible values are
187209
`"expand"` and `"reduce"`.
@@ -190,11 +212,14 @@ def conv2d_process_grad_output(
190212
The processed scaled gradient. Has shape `[batch_size, C_out]` for
191213
`"reduce"` and `[batch_size * O1 * O2, C_out]` for `"expand"`.
192214
"""
193-
# The scaling by `sqrt(batch_size)` when `batch_averaged=True` assumes
194-
# that we are in the reduce setting, i.e. the number of loss terms equals
195-
# the batch size.
196-
batch_size = g.shape[0]
197-
scaling = scaling * sqrt(batch_size) if batch_averaged else scaling
215+
# We have to adjust the scaling to account for the mean reduction of the
216+
# loss used for computing the gradients when loss_average is not None.
217+
if loss_average is not None:
218+
num_loss_terms = g.shape[0] # batch_size
219+
if loss_average == "batch+sequence":
220+
num_loss_terms *= g.shape[2:].numel() # spatial size = O1 * O2
221+
222+
scaling *= sqrt(num_loss_terms)
198223

199224
if kfac_approx == "expand":
200225
# KFAC-expand approximation
@@ -207,15 +232,23 @@ def conv2d_process_grad_output(
207232

208233

209234
def linear_process_grad_output(
210-
g: Tensor, batch_averaged: bool, scaling: float, kfac_approx: str
235+
g: Tensor, loss_average: Union[None, str], scaling: float, kfac_approx: str
211236
) -> Tensor:
212237
"""Process the output gradient of a linear layer before the self-inner product.
213238
214239
Args:
215240
g: Gradient w.r.t. the output of a linear layer. Has shape
216241
`[batch_size, ..., d_out]` where `...` is an arbitrary number of
217242
weight-shared dimensions.
218-
batch_averaged: Whether to multiply with the batch size.
243+
loss_average: Whether the loss function is a mean over per-sample
244+
losses and if yes, over which dimensions the mean is taken.
245+
If `"batch"`, the loss function is a mean over as many terms as
246+
the size of the mini-batch. If `"batch+sequence"`, the loss
247+
function is a mean over as many terms as the size of the
248+
mini-batch times the sequence length, e.g. in the case of
249+
language modeling. If `None`, the loss function is a sum. This
250+
argument is used to ensure that the preconditioner is scaled
251+
consistently with the loss and the gradient. Default: `"batch"`.
219252
scaling: An additional scaling that will be applied to the gradient.
220253
kfac_approx: The KFAC approximation to use for linear weight-sharing
221254
layers. Possible values are `"expand"` and `"reduce"`.
@@ -224,14 +257,21 @@ def linear_process_grad_output(
224257
The processed gradient. Has shape `[batch_size, d_out]` for `"reduce"`
225258
and `[batch_size * ..., d_out]` for `"expand"`.
226259
"""
260+
# We have to adjust the scaling to account for the mean reduction of the
261+
# loss used for computing the gradients when loss_average is not None.
262+
if loss_average is not None:
263+
num_loss_terms = g.shape[0] # batch_size
264+
if loss_average == "batch+sequence":
265+
# Size of all weight-sharing dimensions.
266+
num_loss_terms *= g.shape[1:-1].numel()
267+
268+
scaling *= sqrt(num_loss_terms)
269+
227270
if kfac_approx == "expand":
228271
# KFAC-expand approximation
229272
g = rearrange(g, "b ... d_out -> (b ...) d_out")
230273
else:
231274
# KFAC-reduce approximation
232275
g = reduce(g, "b ... d_out -> b d_out", "sum")
233276

234-
# The use of `g.shape[0]` assumes that the setting of the loss, i.e. the
235-
# number of loss terms, matches the `kfac_approx` that is used.
236-
scaling = scaling * sqrt(g.shape[0]) if batch_averaged else scaling
237277
return g * scaling

test/optim/test_autocast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_autocast():
4848
"momentum": 0.9,
4949
"weight_decay": 1e-2,
5050
"lr_cov": 1e-2,
51-
"batch_averaged": True,
51+
"loss_average": "batch",
5252
"T": 1,
5353
"alpha1": 0.5,
5454
"structures": ("dense", "dense"),

test/optim/test_checkpointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def setup() -> Tuple[Sequential, Module, SINGD]:
4242
"momentum": 0.9,
4343
"weight_decay": 1e-2,
4444
"lr_cov": 1e-2,
45-
"batch_averaged": True,
45+
"loss_average": "batch",
4646
"T": 1,
4747
"alpha1": 0.5,
4848
"structures": ("dense", "dense"),

test/optim/test_gradient_accumulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ def test_gradient_accumulation(reduction: str):
5454
loss_func_mini = CrossEntropyLoss(reduction=reduction)
5555
loss_func_micro = deepcopy(loss_func_mini)
5656

57-
batch_averaged = {"mean": True, "sum": False}[reduction]
57+
loss_average = {"mean": "batch", "sum": None}[reduction]
5858
optim_hyperparams = {
5959
"lr": 5e-4,
6060
"damping": 1e-4,
6161
"momentum": 0.9,
6262
"weight_decay": 1e-2,
6363
"lr_cov": 1e-2,
64-
"batch_averaged": batch_averaged,
64+
"loss_average": loss_average,
6565
"T": 1,
6666
"alpha1": 0.5,
6767
"structures": ("dense", "dense"),

0 commit comments

Comments (0)