@@ -14,8 +14,9 @@
 
 import os
 from collections.abc import Iterable, Iterator, Mapping, Sequence
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -170,12 +171,35 @@ def to_disk(self) -> None:
         torch.save(outputs, fp)
 
 
+@dataclass
+class SharedCycleIteratorState:
+
+    mode: str = "max_size_cycle"
+    dataloaders: List[DataLoader] = field(default_factory=lambda: [])
+    has_finished: Dict[int, bool] = field(default_factory=lambda: {})
+    has_reset: bool = False
+
+    def reset(self) -> None:
+        for dataloader in self.dataloaders:
+            self.has_finished[id(dataloader)] = False
+        self.has_reset = True
+
+    @property
+    def done(self) -> bool:
+        if not self.has_reset:
+            raise MisconfigurationException("Please call reset once all dataloaders have been added.")
+        if len(self.dataloaders) == 1:
+            return False
+        decision_fn = all if self.mode == "max_size_cycle" else any
+        return decision_fn(self.has_finished.values())
+
+
 class CycleIterator:
     """
     Iterator for restarting a dataloader if it runs out of samples
     """
 
-    def __init__(self, loader: Any, length: Optional[int] = None):
+    def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
         """
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
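The core of the change is the `done` property: under `max_size_cycle` the combined iteration ends only once *all* registered loaders have finished at least one pass, while any other mode (e.g. `min_size`) ends as soon as *any* loader is exhausted. A minimal standalone sketch of just that decision logic (illustrative names, not Lightning's API):

```python
# Standalone sketch of the `done` decision with two registered loaders.
def done(mode: str, has_finished: dict) -> bool:
    # "max_size_cycle" waits for every loader; other modes stop at the first.
    decision_fn = all if mode == "max_size_cycle" else any
    return decision_fn(has_finished.values())

finished = {1: True, 2: False}  # loader 1 exhausted, loader 2 still running
assert done("max_size_cycle", finished) is False  # keep going, restart loader 1
assert done("min_size", finished) is True         # first exhaustion ends the epoch
```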
@@ -185,6 +209,15 @@ def __init__(self, loader: Any, length: Optional[int] = None):
         if length is None:
             length = float("inf")
 
+        if not state:
+            state = SharedCycleIteratorState()
+            state.dataloaders.append(loader)
+            state.reset()
+        else:
+            state.dataloaders.append(loader)
+
+        self.state = state
+
         self.length = length
         self.loader = loader
         self._loader_iter = None
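The constructor implements a small registration protocol: an iterator created without a shared state allocates a private one and resets it immediately, while an iterator given an existing state only appends its loader and leaves the single `reset()` call to whoever owns the state. A hedged sketch of the manual wiring (dataset sizes are illustrative; `range(...)` stands in for real datasets):

```python
# Mirrors what _wrap_loaders_max_size_cycle does further down in this diff:
# register every loader against one shared state, then reset exactly once.
from torch.utils.data import DataLoader

state = SharedCycleIteratorState(mode="max_size_cycle")
loaders = [DataLoader(range(10)), DataLoader(range(4))]
iterators = [CycleIterator(dl, length=10, state=state) for dl in loaders]
state.reset()  # `done` would raise MisconfigurationException before this
```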
@@ -205,21 +238,27 @@ def __next__(self) -> Any:
         """
         Fetches the next batch from internal dataloader and restarts
         it if necessary
-
         Returns:
             Any: the resulting batch
-
         Raises:
             StopIteration: if more than :attr:`length` batches have been returned
         """
         # Note: if self.length is `inf`, then the iterator will never stop
-        if self.counter >= self.__len__():
+        if self.counter >= self.__len__() or self.state.done:
             raise StopIteration
 
         try:
             return next(self._loader_iter)
 
         except StopIteration:
+
+            # inform the shared state that this loader has completed
+            self.state.has_finished[id(self.loader)] = True
+
+            # check if iteration should be stopped.
+            if self.state.done:
+                raise StopIteration
+
             self._loader_iter = iter(self.loader)
             return next(self._loader_iter)
 
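With the shared state in place, `__next__` distinguishes two outcomes when its inner loader raises `StopIteration`: if the group is not done, the exhausted loader is silently re-wound (the cycling behavior), and if the group is done, the exception propagates and the combined epoch ends. Continuing the sketch above, with the caveat that the exact lockstep behavior depends on `__iter__` (not shown in this diff):

```python
# With mode="max_size_cycle", the 4-sample loader keeps restarting until the
# 10-sample loader hits its `length` cap, so both yield 10 batches per epoch.
for a, b in zip(iterators[0], iterators[1]):
    ...  # `b` wraps around after its 4th batch; the epoch ends after step 10
```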
@@ -468,10 +507,14 @@ def _wrap_loaders_max_size_cycle(self) -> Any:
 
         # multiple loaders
         if isinstance(self.loaders, (Sequence, Mapping)):
+            state = SharedCycleIteratorState()
+
             self.loaders = apply_to_collection(
-                self.loaders, Iterable, CycleIterator, length=length, wrong_dtype=(Sequence, Mapping)
+                self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
             )
 
+            state.reset()
+
     def __iter__(self) -> Any:
         """
         Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.
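End to end, this machinery is reached through `CombinedLoader`, which builds the shared state and wraps each loader in a `CycleIterator` via `_wrap_loaders_max_size_cycle`. A hedged usage sketch (the import path matches the Lightning version this diff targets and may differ in later releases; batch sizes are illustrative):

```python
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader

loaders = {
    "a": DataLoader(range(10), batch_size=2),  # 5 batches per epoch
    "b": DataLoader(range(4), batch_size=2),   # 2 batches, will cycle
}
combined = CombinedLoader(loaders, mode="max_size_cycle")
for batch in combined:
    ...  # batch == {"a": tensor, "b": tensor}; "b" restarts until "a" ends
```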