Skip to content

Commit 64d84cc

Browse files
mukherylantiga
authored and committed
avoid unnecessary workers with sequential CombinedLoader (#17639)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> (cherry picked from commit c3ad756)
1 parent a81a956 commit 64d84cc

File tree

4 files changed

+51
-7
lines changed

4 files changed

+51
-7
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1414

1515
### Fixed
1616

17+
- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639))
18+
19+
1720
- Fixed a potential bug with uploading model checkpoints to Neptune.ai by uploading files from stream ([#17430](https://github.com/Lightning-AI/lightning/pull/17430))
1821

1922

src/lightning/pytorch/utilities/combined_loader.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def limits(self, limits: Optional[List[Union[int, float]]]) -> None:
108108
self._limits = limits
109109

110110
def __next__(self) -> Tuple[Any, int, int]:
111-
n = len(self.iterators)
111+
n = len(self.iterables)
112112
if n == 0 or self._iterator_idx >= n:
113113
raise StopIteration
114114

@@ -120,7 +120,7 @@ def __next__(self) -> Tuple[Any, int, int]:
120120
raise StopIteration
121121

122122
try:
123-
out = next(self.iterators[self._iterator_idx])
123+
out = next(self.iterators[0])
124124
index = self._idx
125125
self._idx += 1
126126
# batch, batch_idx, dataloader_idx
@@ -131,19 +131,28 @@ def __next__(self) -> Tuple[Any, int, int]:
131131
return self.__next__()
132132

133133
def __iter__(self) -> Self:
134-
super().__iter__()
135134
self._iterator_idx = 0
136135
self._idx = 0
136+
self._load_current_iterator()
137137
return self
138138

139139
def reset(self) -> None:
140140
super().reset()
141141
self._iterator_idx = 0
142142
self._idx = 0
143143

144+
def _load_current_iterator(self) -> None:
145+
# Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
146+
if self._iterator_idx < len(self.iterables):
147+
self.iterators = [iter(self.iterables[self._iterator_idx])]
148+
else:
149+
# No more iterables to step through, return an empty list
150+
self.iterators = []
151+
144152
def _use_next_iterator(self) -> None:
145153
self._iterator_idx += 1
146154
self._idx = 0
155+
self._load_current_iterator()
147156

148157

149158
class _MaxSize(_ModeIterator[List]):

tests/tests_pytorch/loops/test_loops.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -844,8 +844,7 @@ def _get_iterator(self):
844844
# iterable check
845845
0,
846846
# epoch ends
847-
1,
848-
# teardown
847+
0,
849848
1,
850849
]
851850
else:
@@ -855,9 +854,8 @@ def _get_iterator(self):
855854
# iterable check
856855
0,
857856
# epoch ends
857+
0,
858858
1,
859859
2,
860-
# teardown
861-
3,
862860
]
863861
assert val_dataloader.shutdown_workers_epochs == expected

tests/tests_pytorch/utilities/test_combined_loader.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
306306
assert idx == expected - 1
307307

308308

309+
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
310+
def test_combined_loader_simultaneous_workers(mode):
311+
"""Test `CombinedLoader` to check how it initializes dataloader workers."""
312+
313+
class TestDataLoader(DataLoader):
314+
def __init__(self, *args, **kwargs):
315+
super().__init__(*args, **kwargs)
316+
self.workers_active = False
317+
318+
def _get_iterator(self):
319+
self.workers_active = True
320+
return super()._get_iterator()
321+
322+
def _shutdown_workers(self):
323+
self.workers_active = False
324+
super()._shutdown_workers()
325+
326+
loaders = [
327+
TestDataLoader(range(10), batch_size=2, num_workers=0),
328+
TestDataLoader(range(20), batch_size=2, num_workers=0),
329+
]
330+
combined_loader = CombinedLoader(loaders, mode)
331+
# Start the dataloader
332+
_ = iter(combined_loader)
333+
334+
workers_active = []
335+
for loader in loaders:
336+
workers_active.append(loader.workers_active)
337+
338+
# Sequential only starts the first dataloader, other modes start both
339+
expected = [True, False] if mode == "sequential" else [True, True]
340+
assert workers_active == expected
341+
342+
309343
@pytest.mark.parametrize(
310344
("limits", "expected"),
311345
[

0 commit comments

Comments (0)