
Commit fa6720a

tchaton authored and lexierule committed
[Bug] Add SharedCycleIteratorState (#8889)
1 parent 55c375b commit fa6720a

File tree: 3 files changed (+136 / -7 lines)

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+
+## [1.4.3] - 2021-08-17
+
+- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
+
 ## [1.4.2] - 2021-08-10
 
 - Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
@@ -27,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
 
 
+- Fixed infinite loop with CycleIterator and multiple loaders ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))
+
+
 ## [1.4.0] - 2021-07-27
 
 ### Added
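
As a quick illustration of the changelog entry above, here is a minimal sketch of the scenario that previously hung, mirroring the regression test added further below. It assumes pytorch-lightning >= 1.4.3; the RangeIterableDataset class is an illustrative stand-in (an iterable dataset without __len__), not part of the library.

    import torch
    from torch.utils.data import DataLoader, IterableDataset
    from pytorch_lightning.trainer.supporters import CombinedLoader


    class RangeIterableDataset(IterableDataset):
        # Illustrative length-less iterable dataset, similar to the one used in the new test.
        def __init__(self, size: int):
            self.size = size

        def __iter__(self):
            return iter(range(self.size))


    loaders = [
        DataLoader(RangeIterableDataset(10), batch_size=2),
        DataLoader(RangeIterableDataset(20), batch_size=2),
    ]

    # Before this fix, "max_size_cycle" could keep cycling the shorter loader forever,
    # because no CycleIterator knew whether the other loaders had already finished.
    combined = CombinedLoader(loaders, mode="max_size_cycle")
    num_batches = sum(1 for _ in combined)
    assert num_batches == 10  # bounded by the longest loader: 20 samples / batch_size 2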

pytorch_lightning/trainer/supporters.py

Lines changed: 49 additions & 6 deletions
@@ -14,8 +14,9 @@
 
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -170,12 +171,35 @@ def to_disk(self) -> None:
     torch.save(outputs, fp)
 
 
+@dataclass
+class SharedCycleIteratorState:
+
+    mode: str = "max_size_cycle"
+    dataloaders: List[DataLoader] = field(default_factory=lambda: [])
+    has_finished: Dict[int, bool] = field(default_factory=lambda: {})
+    has_reset: bool = False
+
+    def reset(self) -> None:
+        for dataloader in self.dataloaders:
+            self.has_finished[id(dataloader)] = False
+        self.has_reset = True
+
+    @property
+    def done(self) -> bool:
+        if not self.has_reset:
+            raise MisconfigurationException("Please, call reset once all dataloaders have been added.")
+        if len(self.dataloaders) == 1:
+            return False
+        decision_fn = all if self.mode == "max_size_cycle" else any
+        return decision_fn(self.has_finished.values())
+
+
 class CycleIterator:
     """
     Iterator for restarting a dataloader if it runs out of samples
     """
 
-    def __init__(self, loader: Any, length: Optional[int] = None):
+    def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
         """
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
@@ -185,6 +209,15 @@ def __init__(self, loader: Any, length: Optional[int] = None):
         if length is None:
            length = float("inf")
 
+        if not state:
+            state = SharedCycleIteratorState()
+            state.dataloaders.append(loader)
+            state.reset()
+        else:
+            state.dataloaders.append(loader)
+
+        self.state = state
+
         self.length = length
         self.loader = loader
         self._loader_iter = None
@@ -205,21 +238,27 @@ def __next__(self) -> Any:
         """
         Fetches the next batch from internal dataloader and restarts
         it if necessary
-
         Returns:
             Any: the resulting batch
-
         Raises:
             StopIteration: if more then :attr:`length` batches have been returned
         """
         # Note: if self.length is `inf`, then the iterator will never stop
-        if self.counter >= self.__len__():
+        if self.counter >= self.__len__() or self.state.done:
             raise StopIteration
 
         try:
             return next(self._loader_iter)
 
         except StopIteration:
+
+            # inform the shared state this loader has completed
+            self.state.has_finished[id(self.loader)] = True
+
+            # check if iteration should be stopped.
+            if self.state.done:
+                raise StopIteration
+
             self._loader_iter = iter(self.loader)
             return next(self._loader_iter)
 
@@ -468,10 +507,14 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
 
         # multiple loaders
         if isinstance(self.loaders, (Sequence, Mapping)):
+            state = SharedCycleIteratorState()
+
             self.loaders = apply_to_collection(
-                self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)
+                self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
             )
 
+            state.reset()
+
     def __iter__(self) -> Any:
         """
         Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.
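
To make the decision logic above concrete, here is a minimal sketch of how the shared state resolves completion across loaders. It assumes pytorch-lightning >= 1.4.3, where SharedCycleIteratorState is defined in pytorch_lightning.trainer.supporters; since the state only tracks loaders by id(), plain range objects stand in for real dataloaders here.

    from pytorch_lightning.trainer.supporters import SharedCycleIteratorState

    loader_a, loader_b = range(10), range(20)  # any objects work as stand-ins

    state = SharedCycleIteratorState(mode="max_size_cycle")
    state.dataloaders.append(loader_a)
    state.dataloaders.append(loader_b)
    state.reset()  # must be called once all loaders have been registered

    # This is what CycleIterator.__next__ does when its loader raises StopIteration.
    state.has_finished[id(loader_a)] = True
    assert not state.done  # "max_size_cycle" stops only once *all* loaders have finished

    state.has_finished[id(loader_b)] = True
    assert state.done  # every loader has been exhausted at least once, so iteration ends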

tests/trainer/test_supporters.py

Lines changed: 79 additions & 1 deletion
@@ -20,7 +20,7 @@
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import Sampler
+from torch.utils.data.sampler import Sampler, SequentialSampler
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.supporters import (
@@ -59,6 +59,7 @@ def test_tensor_running_accum_reset():
 
 def test_cycle_iterator():
     """Test the cycling function of `CycleIterator`"""
+
     iterator = CycleIterator(range(100), 1000)
     assert len(iterator) == 1000
     for idx, item in enumerate(iterator):
@@ -216,6 +217,83 @@ def test_combined_loader_sequence_min_size():
     assert idx == len(combined_loader) - 1
 
 
+class TestIterableDataset(IterableDataset):
+    def __init__(self, size: int = 10):
+        self.size = size
+
+    def __iter__(self):
+        self.sampler = SequentialSampler(range(self.size))
+        self.sampler_iter = iter(self.sampler)
+        return self
+
+    def __next__(self):
+        return next(self.sampler_iter)
+
+
+@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
+@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
+def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
+    """Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
+    if use_multiple_dataloaders:
+        loaders = [
+            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
+            torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
+        ]
+    else:
+        loaders = [
+            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
+        ]
+
+    combined_loader = CombinedLoader(loaders, mode)
+
+    has_break = False
+
+    for idx, item in enumerate(combined_loader):
+        assert isinstance(item, Sequence)
+        assert len(item) == 2 if use_multiple_dataloaders else 1
+        if not use_multiple_dataloaders and idx == 4:
+            has_break = True
+            break
+
+    if mode == "max_size_cycle":
+        assert combined_loader.loaders[0].state.done == (not has_break)
+    expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
+    assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
+
+
+@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
+def test_combined_loader_sequence_with_map_and_iterable(lengths):
+    class MyIterableDataset(IterableDataset):
+        def __init__(self, size: int = 10):
+            self.size = size
+
+        def __iter__(self):
+            self.sampler = SequentialSampler(range(self.size))
+            self.iter_sampler = iter(self.sampler)
+            return self
+
+        def __next__(self):
+            return next(self.iter_sampler)
+
+    class MyMapDataset(Dataset):
+        def __init__(self, size: int = 10):
+            self.size = size
+
+        def __getitem__(self, index):
+            return index
+
+        def __len__(self):
+            return self.size
+
+    x, y = lengths
+    loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
+    dataloader = CombinedLoader(loaders, mode="max_size_cycle")
+    counter = 0
+    for _ in dataloader:
+        counter += 1
+    assert counter == max(x, y)
+
+
 def test_combined_loader_sequence_max_size_cycle():
     """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
     loaders = [
