16
16
from typing import Any , Callable , Dict , Iterator , List , Literal , Optional , Tuple , Type , Union
17
17
18
18
from torch .utils .data .dataloader import _BaseDataLoaderIter , _MultiProcessingDataLoaderIter
19
- from typing_extensions import Self , TypedDict
19
+ from typing_extensions import Self , TypedDict , override
20
20
21
21
from lightning .fabric .utilities .data import sized_len
22
22
from lightning .pytorch .utilities ._pytree import _map_and_unflatten , _tree_flatten , tree_unflatten
@@ -33,9 +33,11 @@ def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, f
33
33
self ._idx = 0 # what would be batch_idx
34
34
self .limits = limits
35
35
36
+ @override
36
37
def __next__ (self ) -> _ITERATOR_RETURN :
37
38
raise NotImplementedError
38
39
40
+ @override
39
41
def __iter__ (self ) -> Self :
40
42
self .iterators = [iter (iterable ) for iterable in self .iterables ]
41
43
self ._idx = 0
@@ -66,6 +68,7 @@ def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, f
66
68
super ().__init__ (iterables , limits )
67
69
self ._consumed : List [bool ] = []
68
70
71
+ @override
69
72
def __next__ (self ) -> _ITERATOR_RETURN :
70
73
n = len (self .iterators )
71
74
out = [None ] * n # values per iterator
@@ -83,29 +86,34 @@ def __next__(self) -> _ITERATOR_RETURN:
83
86
self ._idx += 1
84
87
return out , index , 0
85
88
89
+ @override
86
90
def __iter__ (self ) -> Self :
87
91
super ().__iter__ ()
88
92
self ._consumed = [False ] * len (self .iterables )
89
93
return self
90
94
95
+ @override
91
96
def __len__ (self ) -> int :
92
97
lengths = _get_iterables_lengths (self .iterables )
93
98
if self .limits is not None :
94
99
return max (min (length , limit ) for length , limit in zip (lengths , self .limits )) # type: ignore[return-value]
95
100
return max (lengths ) # type: ignore[return-value]
96
101
102
+ @override
97
103
def reset (self ) -> None :
98
104
super ().reset ()
99
105
self ._consumed = []
100
106
101
107
102
108
class _MinSize (_ModeIterator ):
109
+ @override
103
110
def __next__ (self ) -> _ITERATOR_RETURN :
104
111
out = [next (it ) for it in self .iterators ]
105
112
index = self ._idx
106
113
self ._idx += 1
107
114
return out , index , 0
108
115
116
+ @override
109
117
def __len__ (self ) -> int :
110
118
lengths = _get_iterables_lengths (self .iterables )
111
119
return min (lengths + self .limits ) if self .limits is not None else min (lengths ) # type: ignore[return-value]
@@ -116,6 +124,7 @@ def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, f
116
124
super ().__init__ (iterables , limits )
117
125
self ._iterator_idx = 0 # what would be dataloader_idx
118
126
127
+ @override
119
128
def __next__ (self ) -> _ITERATOR_RETURN :
120
129
n = len (self .iterables )
121
130
if n == 0 or self ._iterator_idx >= n :
@@ -138,18 +147,21 @@ def __next__(self) -> _ITERATOR_RETURN:
138
147
self ._idx += 1
139
148
return out , index , self ._iterator_idx
140
149
150
+ @override
141
151
def __iter__ (self ) -> Self :
142
152
self ._iterator_idx = 0
143
153
self ._idx = 0
144
154
self ._load_current_iterator ()
145
155
return self
146
156
157
+ @override
147
158
def __len__ (self ) -> int :
148
159
lengths = _get_iterables_lengths (self .iterables )
149
160
if self .limits is not None :
150
161
return sum (min (length , limit ) for length , limit in zip (lengths , self .limits )) # type: ignore[misc]
151
162
return sum (lengths ) # type: ignore[arg-type]
152
163
164
+ @override
153
165
def reset (self ) -> None :
154
166
super ().reset ()
155
167
self ._iterator_idx = 0
@@ -169,6 +181,7 @@ def _use_next_iterator(self) -> None:
169
181
170
182
171
183
class _MaxSize (_ModeIterator ):
184
+ @override
172
185
def __next__ (self ) -> _ITERATOR_RETURN :
173
186
n = len (self .iterators )
174
187
out = [None ] * n
@@ -183,6 +196,7 @@ def __next__(self) -> _ITERATOR_RETURN:
183
196
self ._idx += 1
184
197
return out , index , 0
185
198
199
+ @override
186
200
def __len__ (self ) -> int :
187
201
lengths = _get_iterables_lengths (self .iterables )
188
202
if self .limits is not None :
@@ -329,6 +343,7 @@ def __next__(self) -> _ITERATOR_RETURN:
329
343
out , batch_idx , dataloader_idx = out
330
344
return tree_unflatten (out , self ._spec ), batch_idx , dataloader_idx
331
345
346
+ @override
332
347
def __iter__ (self ) -> Self :
333
348
cls = _SUPPORTED_MODES [self ._mode ]["iterator" ]
334
349
iterator = cls (self .flattened , self ._limits )
0 commit comments