|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
| 15 | +from copy import deepcopy |
15 | 16 | from pathlib import Path |
16 | 17 | from typing import Any, Optional |
17 | 18 |
|
18 | 19 | import pytest |
19 | 20 | import torch |
20 | 21 | from torch import Tensor, nn |
21 | 22 | from torch.optim.swa_utils import get_swa_avg_fn |
22 | | -from torch.utils.data import DataLoader |
| 23 | +from torch.utils.data import DataLoader, Dataset |
23 | 24 |
|
24 | 25 | from lightning.pytorch import LightningModule, Trainer |
25 | 26 | from lightning.pytorch.callbacks import WeightAveraging |
26 | 27 | from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset |
27 | 28 | from tests_pytorch.helpers.runif import RunIf |
28 | 29 |
|
29 | 30 |
|
30 | | -class WeightAveragingTestModel(BoringModel): |
31 | | - def __init__( |
32 | | - self, batch_norm: bool = True, iterable_dataset: bool = False, crash_on_epoch: Optional[int] = None |
33 | | - ) -> None: |
| 31 | +class TestModel(BoringModel): |
| 32 | + def __init__(self, batch_norm: bool = True) -> None: |
34 | 33 | super().__init__() |
35 | 34 | layers = [nn.Linear(32, 32)] |
36 | 35 | if batch_norm: |
37 | 36 | layers.append(nn.BatchNorm1d(32)) |
38 | 37 | layers += [nn.ReLU(), nn.Linear(32, 2)] |
39 | 38 | self.layer = nn.Sequential(*layers) |
40 | | - self.iterable_dataset = iterable_dataset |
41 | | - self.crash_on_epoch = crash_on_epoch |
| 39 | + self.crash_on_epoch = None |
42 | 40 |
|
43 | 41 | def training_step(self, batch: Tensor, batch_idx: int) -> None: |
44 | 42 | if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch: |
45 | | - raise Exception("CRASH TEST") |
| 43 | + raise Exception("CRASH") |
46 | 44 | return super().training_step(batch, batch_idx) |
47 | 45 |
|
48 | | - def train_dataloader(self) -> None: |
49 | | - dataset_class = RandomIterableDataset if self.iterable_dataset else RandomDataset |
50 | | - return DataLoader(dataset_class(32, 32), batch_size=4) |
51 | | - |
52 | 46 | def configure_optimizers(self) -> None: |
53 | 47 | return torch.optim.SGD(self.layer.parameters(), lr=0.1) |
54 | 48 |
|
@@ -194,95 +188,115 @@ def setup(self, trainer, pl_module, stage) -> None: |
194 | 188 | @pytest.mark.parametrize("batch_norm", [True, False]) |
195 | 189 | @pytest.mark.parametrize("iterable_dataset", [True, False]) |
196 | 190 | def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool): |
197 | | - _train(tmp_path, EMATestCallback(), batch_norm=batch_norm, iterable_dataset=iterable_dataset) |
| 191 | + model = TestModel(batch_norm=batch_norm) |
| 192 | + dataset = RandomIterableDataset(32, 32) if iterable_dataset else RandomDataset(32, 32) |
| 193 | + _train(model, dataset, tmp_path, EMATestCallback()) |
198 | 194 |
|
199 | 195 |
|
200 | 196 | @pytest.mark.parametrize( |
201 | 197 | "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))] |
202 | 198 | ) |
203 | 199 | def test_ema_accelerator(tmp_path, accelerator): |
204 | | - _train(tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) |
| 200 | + model = TestModel() |
| 201 | + dataset = RandomDataset(32, 32) |
| 202 | + _train(model, dataset, tmp_path, EMATestCallback(), accelerator=accelerator, devices=1) |
205 | 203 |
|
206 | 204 |
|
207 | 205 | @RunIf(min_cuda_gpus=2, standalone=True) |
208 | 206 | def test_ema_ddp(tmp_path): |
209 | | - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) |
| 207 | + model = TestModel() |
| 208 | + dataset = RandomDataset(32, 32) |
| 209 | + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2) |
210 | 210 |
|
211 | 211 |
|
212 | 212 | @RunIf(min_cuda_gpus=2) |
213 | 213 | def test_ema_ddp_spawn(tmp_path): |
214 | | - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) |
| 214 | + model = TestModel() |
| 215 | + dataset = RandomDataset(32, 32) |
| 216 | + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2) |
215 | 217 |
|
216 | 218 |
|
217 | 219 | @RunIf(skip_windows=True) |
218 | 220 | def test_ema_ddp_spawn_cpu(tmp_path): |
219 | | - _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) |
| 221 | + model = TestModel() |
| 222 | + dataset = RandomDataset(32, 32) |
| 223 | + _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2) |
220 | 224 |
|
221 | 225 |
|
222 | | -@pytest.mark.parametrize("crash_on_epoch", [1, 3]) |
| 226 | +@pytest.mark.parametrize("crash_on_epoch", [1, 3, 5]) |
223 | 227 | def test_ema_resume(tmp_path, crash_on_epoch): |
224 | | - _train_and_resume(tmp_path, crash_on_epoch=crash_on_epoch) |
| 228 | + dataset = RandomDataset(32, 32) |
| 229 | + model1 = TestModel() |
| 230 | + model2 = deepcopy(model1) |
| 231 | + |
| 232 | + _train(model1, dataset, tmp_path, EMATestCallback()) |
| 233 | + |
| 234 | + model2.crash_on_epoch = crash_on_epoch |
| 235 | + model2 = _train_and_resume(model2, dataset, tmp_path) |
| 236 | + |
| 237 | + for param1, param2 in zip(model1.parameters(), model2.parameters()): |
| 238 | + assert torch.allclose(param1, param2, atol=0.001) |
225 | 239 |
|
226 | 240 |
|
227 | 241 | @RunIf(skip_windows=True) |
228 | 242 | def test_ema_resume_ddp(tmp_path): |
229 | | - _train_and_resume(tmp_path, crash_on_epoch=3, use_ddp=True) |
| 243 | + model = TestModel() |
| 244 | + model.crash_on_epoch = 3 |
| 245 | + dataset = RandomDataset(32, 32) |
| 246 | + _train_and_resume(model, dataset, tmp_path, strategy="ddp_spawn", devices=2) |
230 | 247 |
|
231 | 248 |
|
232 | 249 | def test_swa(tmp_path): |
233 | | - _train(tmp_path, SWATestCallback()) |
| 250 | + model = TestModel() |
| 251 | + dataset = RandomDataset(32, 32) |
| 252 | + _train(model, dataset, tmp_path, SWATestCallback()) |
234 | 253 |
|
235 | 254 |
|
236 | 255 | def _train( |
| 256 | + model: TestModel, |
| 257 | + dataset: Dataset, |
237 | 258 | tmp_path: str, |
238 | 259 | callback: WeightAveraging, |
239 | | - batch_norm: bool = True, |
240 | 260 | strategy: str = "auto", |
241 | 261 | accelerator: str = "cpu", |
242 | 262 | devices: int = 1, |
243 | | - iterable_dataset: bool = False, |
244 | 263 | checkpoint_path: Optional[str] = None, |
245 | | - crash_on_epoch: Optional[int] = None, |
246 | | -) -> None: |
| 264 | + will_crash: bool = False, |
| 265 | +) -> TestModel: |
| 266 | + deterministic = accelerator == "cpu" |
247 | 267 | trainer = Trainer( |
248 | | - default_root_dir=tmp_path, |
249 | | - enable_progress_bar=False, |
250 | | - enable_model_summary=False, |
| 268 | + accelerator=accelerator, |
| 269 | + strategy=strategy, |
| 270 | + devices=devices, |
251 | 271 | logger=False, |
| 272 | + callbacks=callback, |
252 | 273 | max_epochs=8, |
253 | 274 | num_sanity_val_steps=0, |
254 | | - callbacks=callback, |
| 275 | + enable_checkpointing=will_crash, |
| 276 | + enable_progress_bar=False, |
| 277 | + enable_model_summary=False, |
255 | 278 | accumulate_grad_batches=2, |
256 | | - strategy=strategy, |
257 | | - accelerator=accelerator, |
258 | | - devices=devices, |
259 | | - ) |
260 | | - model = WeightAveragingTestModel( |
261 | | - batch_norm=batch_norm, iterable_dataset=iterable_dataset, crash_on_epoch=crash_on_epoch |
| 279 | + deterministic=deterministic, |
| 280 | + default_root_dir=tmp_path, |
262 | 281 | ) |
263 | | - |
264 | | - if crash_on_epoch is None: |
265 | | - trainer.fit(model, ckpt_path=checkpoint_path) |
| 282 | + dataloader = DataLoader(dataset, batch_size=4, shuffle=False) |
| 283 | + if will_crash: |
| 284 | + with pytest.raises(Exception, match="CRASH"): |
| 285 | + trainer.fit(model, dataloader, ckpt_path=checkpoint_path) |
266 | 286 | else: |
267 | | - with pytest.raises(Exception, match="CRASH TEST"): |
268 | | - trainer.fit(model, ckpt_path=checkpoint_path) |
269 | | - |
| 287 | + trainer.fit(model, dataloader, ckpt_path=checkpoint_path) |
270 | 288 | assert trainer.lightning_module == model |
271 | 289 |
|
272 | 290 |
|
273 | | -def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None: |
274 | | - strategy = "ddp_spawn" if use_ddp else "auto" |
275 | | - devices = 2 if use_ddp else 1 |
276 | | - |
277 | | - _train( |
278 | | - tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, crash_on_epoch=crash_on_epoch |
279 | | - ) |
| 291 | +def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices: int = 1, **kwargs) -> TestModel: |
| 292 | + _train(model, dataset, tmp_path, EMATestCallback(devices=devices), devices=devices, will_crash=True, **kwargs) |
280 | 293 |
|
281 | 294 | checkpoint_dir = Path(tmp_path) / "checkpoints" |
282 | 295 | checkpoint_names = os.listdir(checkpoint_dir) |
283 | 296 | assert len(checkpoint_names) == 1 |
284 | 297 | checkpoint_path = str(checkpoint_dir / checkpoint_names[0]) |
285 | 298 |
|
286 | | - _train( |
287 | | - tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, checkpoint_path=checkpoint_path |
288 | | - ) |
| 299 | + model = TestModel.load_from_checkpoint(checkpoint_path) |
| 300 | + callback = EMATestCallback(devices=devices) |
| 301 | + _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs) |
| 302 | + return model |
0 commit comments