Skip to content

Commit 1363410

Browse files
authored
misc: simplify next_batch code (#304)
There is no need to maintain the `_batch_list` list, since only two values need to be tracked: the batch to return and the one after it. The `next_batch()` method is refactored to make it more readable.
1 parent 4192053 commit 1363410

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

infscale/module/dataset.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def collate_fn(examples):
106106
mmd.trace_inputs = trace_inputs
107107

108108
self.model_group = mmd.model_group
109-
self._batch_list: list[Tensor | None] = []
109+
self._curr_batch: Tensor = None
110110

111111
def configure(
112112
self, micro_batch_size: int, device: torch.device, in_memory: bool, replay: int
@@ -133,8 +133,9 @@ def _inner_send_b2d(batch):
133133

134134
if not self._in_memory:
135135
self._send_batch_to_device = _inner_send_b2d
136-
batch = next(self.data_iter)
137-
self._batch_list.append(batch)
136+
# set the first batch to _curr_batch so that the end of replay can
137+
# be checked at the same time when the last batch is returned
138+
self._curr_batch = next(self.data_iter)
138139
return
139140

140141
# do nothing in case of in-memory loading
@@ -147,8 +148,10 @@ def _inner_send_b2d(batch):
147148
self.batches.append(batch)
148149

149150
self.data_iter = iter(self.batches)
150-
batch = next(self.data_iter)
151-
self._batch_list.append(batch)
151+
152+
# set the first batch to _curr_batch so that the end of replay can
153+
# be checked at the same time when the last batch is returned
154+
self._curr_batch = next(self.data_iter)
152155

153156
def _handle_dataset_playback(self) -> Tensor | None:
154157
if self._replay == 0:
@@ -166,20 +169,20 @@ def _handle_dataset_playback(self) -> Tensor | None:
166169

167170
def next_batch(self) -> tuple[Tensor, bool]:
168171
"""Return next data tensor and bool if last bach."""
172+
# take a batch to return
173+
curr_batch = self._curr_batch
174+
# noop for in-memory case; otherwise, load batch to a correct device
175+
self._send_batch_to_device(curr_batch)
176+
177+
# load a new batch to _curr_batch
169178
try:
170-
batch = next(self.data_iter)
171-
self._batch_list.append(batch)
179+
self._curr_batch = next(self.data_iter)
172180
except StopIteration:
173-
batch = self._handle_dataset_playback()
174-
self._batch_list.append(batch)
175-
176-
batch = self._batch_list.pop(0)
177-
# noop for in-memory case; otherwise, load batch to a correct device
178-
self._send_batch_to_device(batch)
181+
self._curr_batch = self._handle_dataset_playback()
179182

180-
is_last = self._batch_list[0] is None
183+
is_last = self._curr_batch is None
181184

182-
return batch, is_last
185+
return curr_batch, is_last
183186

184187
@staticmethod
185188
def create_image_dataset(

0 commit comments

Comments
 (0)