
Commit 5a69057

Author: Seppo Enarvi (committed)
WeightAveraging calls the configure_model hook but issues a warning
1 parent 5deb0bb commit 5a69057

File tree: 2 files changed (+51 −2 lines)

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 18 additions & 0 deletions
@@ -26,6 +26,7 @@

 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT

@@ -55,6 +56,13 @@ class WeightAveraging(Callback):
     See also the documentation on the :ref:`weight averaging callbacks <advanced/training_tricks:Weight Averaging>`
     provided by Lightning.

+    Note:
+        To ensure that the :class:`AveragedModel` will contain all layers,
+        :meth:`~lightning.pytorch.callbacks.weight_averaging.WeightAveraging.setup` will call
+        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
+        :class:`AveragedModel`. However, that hook is not called in a strategy aware context, sharded models do not work
+        with weight averaging, and a warning will be issued.
+
     Example::

         from lightning.pytorch.callbacks import WeightAveraging
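
The new docstring note targets models that defer building their layers to configure_model. A minimal sketch of such a model (the class name, layer sizes, and training step are illustrative, not part of this commit):

    import torch
    from torch import nn
    from lightning.pytorch import LightningModule


    class LazyLayersModel(LightningModule):
        """Illustrative model: layers are created in configure_model, not in __init__."""

        def __init__(self):
            super().__init__()
            self.layer = None  # built later, e.g. so a sharding strategy could place it

        def configure_model(self):
            # Guard against repeated calls: the hook may run more than once.
            if self.layer is None:
                self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

If the AveragedModel were built from such a module before configure_model had run, the averaged copy would be missing every layer, which is what the new setup behaviour below avoids.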
@@ -137,6 +145,16 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """
         if stage == "fit":
             device = self._device or pl_module.device
+
+            # If the configure_model hook is overridden, call it to create the layers before constructing the
+            # AveragedModel. However, sharding will not be done and a warning will be issued.
+            if is_overridden("configure_model", pl_module):
+                rank_zero_warn(
+                    "You're using the WeightAveraging callback with a model that overrides the configure_model "
+                    "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
+                )
+                pl_module.configure_model()
+
             self._average_model = AveragedModel(
                 model=pl_module, device=device, use_buffers=self._use_buffers, **self._kwargs
             )
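
To see the mechanism outside of a Trainer run, here is a small torch-only sketch (LazyLayersModel is the illustrative class sketched after the docstring note above; this is not code from the commit):

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = LazyLayersModel()

    # Before configure_model runs, the module has no parameters at all, so an AveragedModel
    # created at this point could never track the real layers.
    assert len(list(model.parameters())) == 0

    model.configure_model()
    average_model = AveragedModel(model=model, use_buffers=True)

    # After the hook, the averaged copy is a deep copy that contains the full Sequential.
    assert isinstance(average_model.module.layer, torch.nn.Sequential)

The new branch in setup() does this automatically when it detects an overridden configure_model, and the rank-zero warning is a reminder that the hook runs outside the strategy, so FSDP-style sharding of those layers is not applied.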

tests/tests_pytorch/callbacks/test_weight_averaging.py

Lines changed: 33 additions & 2 deletions
@@ -47,6 +47,19 @@ def configure_optimizers(self) -> None:
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)


+class LargeTestModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.layer = None
+
+    def configure_model(self):
+        print("XXX configure_model")
+        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
 class EMAAveragingFunction:
     """EMA averaging function.

@@ -252,8 +265,26 @@ def test_swa(tmp_path):
     _train(model, dataset, tmp_path, SWATestCallback())


+@pytest.mark.parametrize(
+    ("strategy", "accelerator", "devices"),
+    [
+        ("auto", "cpu", 1),
+        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("fsdp", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("ddp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
+        pytest.param("fsdp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
+    ],
+)
+def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
+    model = LargeTestModel()
+    dataset = RandomDataset(32, 32)
+    callback = EMATestCallback()
+    _train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)
+    assert isinstance(callback._average_model.module.layer, nn.Sequential)
+
+
 def _train(
-    model: TestModel,
+    model: BoringModel,
     dataset: Dataset,
     tmp_path: str,
     callback: WeightAveraging,
@@ -262,7 +293,7 @@ def _train(
     devices: int = 1,
     checkpoint_path: Optional[str] = None,
     will_crash: bool = False,
-) -> TestModel:
+) -> None:
     deterministic = accelerator == "cpu"
     trainer = Trainer(
         accelerator=accelerator,
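
The new test relies on helpers that are local to this test module (EMATestCallback, _train, RunIf) and are not shown in the diff. A rough standalone equivalent of the single-device CPU case, using only the public WeightAveraging API (the EMA averaging function and the Trainer flags are assumptions, not taken from the test helpers):

    import torch
    from torch import nn
    from torch.optim.swa_utils import get_ema_avg_fn
    from torch.utils.data import DataLoader
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging
    from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


    class LargeTestModel(BoringModel):
        """Same shape as the test model added above: the layer only exists after configure_model."""

        def __init__(self):
            super().__init__()
            self.layer = None

        def configure_model(self):
            self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    callback = WeightAveraging(avg_fn=get_ema_avg_fn())
    trainer = Trainer(
        max_epochs=1,
        accelerator="cpu",
        devices=1,
        callbacks=[callback],
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer.fit(LargeTestModel(), DataLoader(RandomDataset(32, 32), batch_size=4))

    # With this commit, setup() calls configure_model first, so the averaged copy holds the full Sequential.
    assert isinstance(callback._average_model.module.layer, nn.Sequential)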

0 commit comments