Skip to content

Commit 28ddcc1

Browse files
committed
update scheduler
1 parent 4cd8c85 commit 28ddcc1

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,6 @@ def do_prefill_dp(self):
250250
return ret
251251

252252
def do_prefill_default(self):
253-
if self.spec_decoding:
254-
return True
255-
256253
# decoding if no waiting
257254
scheduler = self.scheduler
258255
if not scheduler.has_waiting():
@@ -298,7 +295,7 @@ async def prefetch_next_inputs(self):
298295
else:
299296
num_running = scheduler.num_running()
300297
is_decoding = self.forward_inputs['inputs'].is_decoding
301-
running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding else 0
298+
running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding or self.spec_decoding else 0
302299

303300
if num_running > running_threshold:
304301
enable = True
@@ -1269,7 +1266,6 @@ async def _async_loop_main(
12691266
if idx == num_loops - 1:
12701267
scheduler.collect_migration_done()
12711268
forward_inputs, next_running = await inputs_maker.prefetch_next_inputs()
1272-
12731269
# send output
12741270
out = await self.executor.get_output_async()
12751271
if out is not None:

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ def gather(self, output):
706706
def get_output(self):
707707
"""Get tmp_output."""
708708
if not return_logits:
709-
return self._output[:, -1:]
709+
return self._output[:, -1:], None
710710
torch.cuda.synchronize()
711711
return self._output, self._aux_output
712712

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@ def _schedule_prefill(self):
189189
copy_map: Dict[int, int] = dict()
190190
running: SeqList = []
191191
token_count = 0
192+
prealloc_size = self.num_spec_tokens or self.num_spec_tokens - 1
192193

193194
def _to_running(seq: SchedulerSequence):
194195
"""To running."""
195196
seq.status = MessageStatus.RUNNING
196197
running.append(seq)
197198
nonlocal token_count
198199
token_count += seq.num_token_ids
199-
token_count += self.num_spec_tokens
200200
token_count += len(seq.spec_token_ids)
201201

202202
def __evict_for_seq(seq: SchedulerSequence, waiting):
@@ -205,7 +205,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting):
205205
hanging = reversed(self.hanging)
206206
waiting = reversed(waiting)
207207
evictable = list(chain(hanging, waiting))
208-
return eviction_helper.evict_for_seq(seq, evictable, prealloc_size=self.num_spec_tokens)
208+
return eviction_helper.evict_for_seq(seq, evictable, prealloc_size=prealloc_size)
209209

210210
def _reorder_waiting():
211211
"""Reorder waiting."""
@@ -218,7 +218,7 @@ def _reorder_waiting():
218218
waiting = _reorder_waiting()
219219
while len(waiting) > 0 and len(running) < max_batches:
220220
seq = waiting.pop(0)
221-
cur_token_count = token_count + seq.num_token_ids + self.num_spec_tokens + len(seq.spec_token_ids)
221+
cur_token_count = token_count + seq.num_token_ids + len(seq.spec_token_ids)
222222
if (len(running) > 0 and cur_token_count > self.cache_config.max_prefill_token_num):
223223
break
224224

@@ -228,7 +228,7 @@ def _reorder_waiting():
228228
break
229229

230230
# allocate session memory
231-
self.block_manager.allocate(seq, prealloc_size=self.num_spec_tokens)
231+
self.block_manager.allocate(seq, prealloc_size=prealloc_size)
232232
_to_running(seq)
233233

234234
seq.record_event(EventType.SCHEDULED)

0 commit comments

Comments (0)