@@ -106,7 +106,7 @@ def collate_fn(examples):
106106 mmd .trace_inputs = trace_inputs
107107
108108 self .model_group = mmd .model_group
109- self ._batch_list : list [ Tensor | None ] = []
109+ self ._curr_batch : Tensor = None
110110
111111 def configure (
112112 self , micro_batch_size : int , device : torch .device , in_memory : bool , replay : int
@@ -133,8 +133,9 @@ def _inner_send_b2d(batch):
133133
134134 if not self ._in_memory :
135135 self ._send_batch_to_device = _inner_send_b2d
136- batch = next (self .data_iter )
137- self ._batch_list .append (batch )
136+ # prefetch the first batch into _curr_batch so that the end of replay
137+ # can be detected at the same time as the last batch is returned
138+ self ._curr_batch = next (self .data_iter )
138139 return
139140
140141 # do nothing in case of in-memory loading
@@ -147,8 +148,10 @@ def _inner_send_b2d(batch):
147148 self .batches .append (batch )
148149
149150 self .data_iter = iter (self .batches )
150- batch = next (self .data_iter )
151- self ._batch_list .append (batch )
151+
152+ # prefetch the first batch into _curr_batch so that the end of replay
153+ # can be detected at the same time as the last batch is returned
154+ self ._curr_batch = next (self .data_iter )
152155
153156 def _handle_dataset_playback (self ) -> Tensor | None :
154157 if self ._replay == 0 :
@@ -166,20 +169,20 @@ def _handle_dataset_playback(self) -> Tensor | None:
166169
167170 def next_batch (self ) -> tuple [Tensor , bool ]:
168171 """Return the next data tensor and a bool telling whether it is the last batch."""
172+ # take a batch to return
173+ curr_batch = self ._curr_batch
174+ # noop for in-memory case; otherwise, load batch to a correct device
175+ self ._send_batch_to_device (curr_batch )
176+
177+ # load a new batch to _curr_batch
169178 try :
170- batch = next (self .data_iter )
171- self ._batch_list .append (batch )
179+ self ._curr_batch = next (self .data_iter )
172180 except StopIteration :
173- batch = self ._handle_dataset_playback ()
174- self ._batch_list .append (batch )
175-
176- batch = self ._batch_list .pop (0 )
177- # noop for in-memory case; otherwise, load batch to a correct device
178- self ._send_batch_to_device (batch )
181+ self ._curr_batch = self ._handle_dataset_playback ()
179182
180- is_last = self ._batch_list [ 0 ] is None
183+ is_last = self ._curr_batch is None
181184
182- return batch , is_last
185+ return curr_batch , is_last
183186
184187 @staticmethod
185188 def create_image_dataset (
0 commit comments