Commit 4192053

fix: dataset replay last batch (#300)
When using the request generator at a lower sending rate, the server's send and recv paths fell out of sync. In the previous implementation, the dataset returned None once no batches remained, so with the longer wait times introduced by the slower rate, the server finished processing the last batches before it received the signal that they were done. The fix detects the last batch early and sends that information along with the batches, so the server knows which batch is the last one.
1 parent 33eda74
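The core of the change is a one-step lookahead: instead of emitting a trailing None sentinel after the data runs out, the producer buffers one batch ahead and tags the final real batch as it goes out. A minimal, hypothetical sketch of that pattern (the names here are illustrative, not from the repo):

```python
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")

def with_last_flag(items: Iterable[T]) -> Iterator[tuple[T, bool]]:
    """Yield (item, is_last) pairs by holding one item in a lookahead buffer."""
    it = iter(items)
    try:
        prev = next(it)
    except StopIteration:
        return  # nothing to yield for an empty source
    for curr in it:
        yield prev, False  # another item is already buffered, so not last
        prev = curr
    yield prev, True  # the buffer now holds the final item

print(list(with_last_flag([10, 20, 30])))
# [(10, False), (20, False), (30, True)]
```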

3 files changed: +22 −17 lines

infscale/execution/pipeline.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -273,10 +273,9 @@ async def _server_send(self, router: Router):
         self._end_of_send = False
 
         async def _inner_send(batches: list[torch.Tensor | None]) -> None:
-            for batch in batches:
-                if batch is None:
+            for batch, is_last in batches:
+                if is_last:
                     self._end_of_send = True
-                    break
 
             await self._wait_tx_permission()
```
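For context, a self-contained sketch of the consumer-side behavior this hunk establishes, with a stand-in send() (hypothetical, not the repo's transmit path): the flagged batch is now real data that must still go out, so the loop records end-of-send instead of breaking early.

```python
import asyncio

async def send(batch) -> None:
    """Hypothetical stand-in for the real transmit step."""
    print("sent", batch)

async def inner_send(batches: list[tuple[int, bool]]) -> bool:
    """Mimic the fixed loop: record end-of-send but keep sending."""
    end_of_send = False
    for batch, is_last in batches:
        if is_last:
            end_of_send = True  # no break: the flagged batch is real data
        await send(batch)
    return end_of_send

print(asyncio.run(inner_send([(101, False), (102, True)])))
# sent 101 / sent 102 / True
```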

infscale/module/dataset.py

Lines changed: 13 additions & 7 deletions

```diff
@@ -106,6 +106,7 @@ def collate_fn(examples):
         mmd.trace_inputs = trace_inputs
 
         self.model_group = mmd.model_group
+        self._batch_list: list[Tensor | None] = []
 
     def configure(
         self, micro_batch_size: int, device: torch.device, in_memory: bool, replay: int
@@ -132,6 +133,8 @@ def _inner_send_b2d(batch):
 
         if not self._in_memory:
             self._send_batch_to_device = _inner_send_b2d
+            batch = next(self.data_iter)
+            self._batch_list.append(batch)
             return
 
         # do nothing in case of in-memory loading
@@ -144,11 +147,12 @@ def _inner_send_b2d(batch):
             self.batches.append(batch)
 
         self.data_iter = iter(self.batches)
+        batch = next(self.data_iter)
+        self._batch_list.append(batch)
 
     def _handle_dataset_playback(self) -> Tensor | None:
         if self._replay == 0:
             return None
-
         # this ensures self._replay decreases to zero or
         # stays as -1 (infinite)
         self._replay = max(self._replay - 1, -1)
@@ -160,20 +164,22 @@ def _handle_dataset_playback(self) -> Tensor | None:
 
         return next(self.data_iter)
 
-    def next_batch(self) -> Tensor | None:
-        """Return next data tensor.
-
-        Once all the data is consumed, it returns None.
-        """
+    def next_batch(self) -> tuple[Tensor, bool]:
+        """Return next data tensor and bool if last batch."""
         try:
             batch = next(self.data_iter)
+            self._batch_list.append(batch)
         except StopIteration:
             batch = self._handle_dataset_playback()
+            self._batch_list.append(batch)
 
+        batch = self._batch_list.pop(0)
         # noop for in-memory case; otherwise, load batch to a correct device
         self._send_batch_to_device(batch)
 
-        return batch
+        is_last = self._batch_list[0] is None
+
+        return batch, is_last
 
     @staticmethod
     def create_image_dataset(
```
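A runnable, hypothetical reduction of the new next_batch() control flow (ints stand in for tensors, and the class is not the repo's): the buffer is pre-filled with one batch during configuration, each call appends the next batch (or the None that playback returns once replay is exhausted) and pops the oldest, and peeking at the new head tells the caller whether the popped batch was the last one.

```python
class LookaheadDataset:
    """Illustrative stand-in for the dataset's one-batch lookahead buffer."""

    def __init__(self, data: list[int], replay: int = 0) -> None:
        self._data = data
        self._replay = replay
        self.data_iter = iter(self._data)
        # pre-fill one batch, mirroring what configure() now does
        self._batch_list: list[int | None] = [next(self.data_iter)]

    def _handle_dataset_playback(self) -> int | None:
        if self._replay == 0:
            return None  # no replays left; None marks the end
        self._replay = max(self._replay - 1, -1)
        self.data_iter = iter(self._data)
        return next(self.data_iter)

    def next_batch(self) -> tuple[int, bool]:
        try:
            self._batch_list.append(next(self.data_iter))
        except StopIteration:
            self._batch_list.append(self._handle_dataset_playback())
        batch = self._batch_list.pop(0)
        is_last = self._batch_list[0] is None  # peek at the lookahead slot
        return batch, is_last

ds = LookaheadDataset([1, 2, 3])
print([ds.next_batch() for _ in range(3)])
# [(1, False), (2, False), (3, True)]
```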

infscale/request/generator.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -61,7 +61,7 @@ async def get(self) -> list[Tensor | None]:
 class DefaultGenerator(Generator):
     """DefaultGenerator class."""
 
-    async def get(self) -> list[Tensor | None]:
+    async def get(self) -> list[tuple[Tensor, bool]]:
         """Return one batch of requests as a list.
 
         initialize() method must be called once before calling this method.
@@ -99,10 +99,10 @@ async def _generate(self) -> None:
         await self._gen_evt.wait()
 
         while True:
-            batch = self._dataset.next_batch()
-            await self._queue.put(batch)
+            batch, is_last = self._dataset.next_batch()
+            await self._queue.put((batch, is_last))
 
-            if batch is None:
+            if is_last:
                 break
 
         self._mc.update(self._seqno)
@@ -114,7 +114,7 @@ async def _generate(self) -> None:
     def _compute_iat(self):
         return np.random.exponential(scale=1 / self._batch_rate)
 
-    async def get(self) -> list[Tensor | None]:
+    async def get(self) -> list[tuple[Tensor, bool]]:
         """Return one batch of requests.
 
         initialize() method must be called once before calling this method.
@@ -124,8 +124,8 @@ async def get(self) -> list[Tensor | None]:
         batches = []
         while True:
             # this guarantees at least one batch of requests is returned
-            batch = await self._queue.get()
-            batches.append(batch)
+            batch, is_last = await self._queue.get()
+            batches.append((batch, is_last))
 
             if self._queue.empty():
                 break
```
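Finally, a hypothetical end-to-end sketch of the queue handoff after the change (names are illustrative): the producer enqueues (batch, is_last) tuples and stops on the flag instead of on a None batch, while the drain side returns whatever has accumulated, always at least one tuple.

```python
import asyncio

async def produce(queue: asyncio.Queue, batches: list[tuple[int, bool]]) -> None:
    for batch, is_last in batches:
        await queue.put((batch, is_last))
        if is_last:
            break  # the flag replaces the old `if batch is None` check

async def drain(queue: asyncio.Queue) -> list[tuple[int, bool]]:
    out = []
    while True:
        out.append(await queue.get())  # blocks until at least one tuple arrives
        if queue.empty():
            break
    return out

async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    await produce(q, [(1, False), (2, False), (3, True)])
    print(await drain(q))  # [(1, False), (2, False), (3, True)]

asyncio.run(main())
```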
