
Commit a6beac9

justusschock authored and lexierule committed

Bugfix/swa iterable dset (#8172)

* add test
* add fix
* Update CHANGELOG.md

1 parent a04a30a · commit a6beac9

File tree

2 files changed: +17 -8 lines changed

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 2 additions & 1 deletion

@@ -222,7 +222,8 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
             trainer.num_training_batches += 1
             trainer.train_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches
-            trainer.accumulate_grad_batches = len(trainer.train_dataloader)
+
+            trainer.accumulate_grad_batches = trainer.num_training_batches
 
     def on_train_epoch_end(self, trainer: 'pl.Trainer', *args):
         trainer.train_loop._skip_backward = False
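Why the one-line change fixes the bug: an IterableDataset generally has no usable __len__, so len(trainer.train_dataloader) can raise a TypeError at the epoch where SWA performs its batch-norm update, whereas trainer.num_training_batches is already tracked by the trainer (and was just incremented above). A minimal sketch of the underlying failure mode, independent of Lightning; the class names are illustrative, not from this commit:

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class MapStyle(Dataset):
    # Map-style dataset: defines __len__, so len() on its DataLoader works.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class Stream(IterableDataset):
    # Iterable-style dataset: no __len__, so len() on its DataLoader fails.
    def __iter__(self):
        return (torch.randn(32) for _ in range(64))


print(len(DataLoader(MapStyle(), batch_size=2)))  # 32 batches

try:
    len(DataLoader(Stream(), batch_size=2))
except TypeError as err:
    # The removed line, trainer.accumulate_grad_batches = len(trainer.train_dataloader),
    # could hit exactly this kind of error when the training data is iterable-style.
    print("len() failed:", err)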

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 15 additions & 7 deletions

@@ -23,7 +23,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
 from tests.helpers.runif import RunIf
 
 if _TORCH_GREATER_EQUAL_1_6:
@@ -33,22 +33,27 @@
 
 class SwaTestModel(BoringModel):
 
-    def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
+    def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
         self.interval = interval
+        self.iterable_dataset = iterable_dataset
 
     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
         loss = self.loss(batch, output)
         return {"loss": loss}
 
     def train_dataloader(self):
-        return DataLoader(RandomDataset(32, 64), batch_size=2)
+
+        dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
+        dset = dset_cls(32, 64)
+
+        return DataLoader(dset, batch_size=2)
 
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
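The test model above switches to RandomIterableDataset, which is imported from tests.helpers.boring_model but defined outside this diff. For orientation, a plausible minimal sketch of such a helper, assuming a (size, count) constructor inferred from the dset_cls(32, 64) call; the real helper may differ:

import torch
from torch.utils.data import IterableDataset


class RandomIterableDataset(IterableDataset):
    # Iterable-style stand-in for RandomDataset. It deliberately provides no
    # __len__, which is precisely what the new test parametrization exercises.
    def __init__(self, size: int, count: int):
        self.size = size
        self.count = count

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)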
@@ -107,8 +112,10 @@ def on_train_end(self, trainer, pl_module):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
-    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
+def train_with_swa(
+    tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False
+):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -154,8 +161,9 @@ def test_swa_callback_1_gpu(tmpdir):
 
 @RunIf(min_torch="1.6.0")
 @pytest.mark.parametrize("batchnorm", (True, False))
-def test_swa_callback(tmpdir, batchnorm: bool):
-    train_with_swa(tmpdir, batchnorm=batchnorm)
+@pytest.mark.parametrize('iterable_dataset', (True, False))
+def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool):
+    train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset)
 
 
 @RunIf(min_torch="1.6.0")
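For context, roughly how the fixed code path is reached from user code: the public StochasticWeightAveraging callback is what drives on_train_epoch_start above. A hedged sketch mirroring the parametrized test; SwaTestModel is the test model from this diff, and limit_train_batches is an assumption used here to keep the iterable dataset's batch count finite, not a value this commit sets:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Model that serves random data through an IterableDataset (see the diff above).
model = SwaTestModel(iterable_dataset=True)

trainer = Trainer(
    max_epochs=5,
    # Start SWA at epoch 2 with a constant SWA learning rate of 0.1, matching
    # the values the test passes to SwaTestCallback.
    callbacks=[StochasticWeightAveraging(swa_epoch_start=2, swa_lrs=0.1)],
    # Assumption: with an IterableDataset, Lightning cannot infer the number of
    # training batches, so bound it explicitly for this sketch.
    limit_train_batches=4,
)
trainer.fit(model)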
