Commit 7e2c8b8

add warmup to draft and support long input split
1 parent c7282c4

Showing 8 changed files with 230 additions and 180 deletions.


benchmark/benchmark_serving.py

Lines changed: 4 additions & 3 deletions

@@ -20,7 +20,8 @@ def get_launching_server_cmd(model_path, backend, server_config):
         # Convert snake_case to kebab-case for command line args
         key = key.replace('_', '-')
         cmd.append(f'--{key}')
-        cmd.append(str(value))
+        if str(value):
+            cmd.append(str(value))
     # Special handling for proxy server case
     if server_config.get('proxy_url') and server_config.get('dp'):
         cmd.append('--allow-terminate-by-client')
@@ -66,9 +67,9 @@ def get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]:
         server_ip = server_config.get('server_ip', '0.0.0.0')
         server_port = server_config.get('server_port', 23333)
     elif backend == 'sglang':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 30000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000))
     elif backend == 'vllm':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 8000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000))
     else:
         raise ValueError(f'unknown backend: {backend}')
     return server_ip, server_port
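Note on the change above: the command builder now appends a value only when it stringifies to something non-empty, so empty-valued options become bare switches. A minimal standalone sketch of that rule (the helper name and the sample config are hypothetical, not part of the benchmark script):

from typing import Dict, List


def build_cli_args(server_config: Dict) -> List[str]:
    cmd = []
    for key, value in server_config.items():
        # Convert snake_case to kebab-case for command line args.
        key = key.replace('_', '-')
        cmd.append(f'--{key}')
        if str(value):  # an empty string keeps the option as a value-less flag
            cmd.append(str(value))
    return cmd


# e.g. {'tp': 2, 'enable_metrics': ''} -> ['--tp', '2', '--enable-metrics']
print(build_cli_args({'tp': 2, 'enable_metrics': ''}))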

lmdeploy/pytorch/configurations/llama.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False,
            # update draft model arch
            assert speculative_config is not None
            hf_config.architectures[0] = speculative_config.method.capitalize() + hf_config.architectures[0]
+            cfg.vocab_size = getattr(hf_config, 'draft_vocab_size', hf_config.vocab_size)
        elif speculative_config is not None:
            # add aux_hidden_state_layers for eagle3
            if speculative_config.method == 'eagle3':

lmdeploy/pytorch/engine/engine.py

Lines changed: 1 addition & 1 deletion

@@ -910,7 +910,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,

    def _make_spec_stats(self, seqs: SeqList, next_token_ids: torch.LongTensor):
        """Make spec stats."""
-        debug = True
+        debug = False
        all_stats = [None] * len(seqs)
        if self.speculative_config is not None and (debug or self.engine_config.enable_metrics):
            if debug and not hasattr(self, 'spec_stats'):

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 129 additions & 24 deletions

@@ -333,6 +333,7 @@ def __init__(
        self.method = specdecode_config.method
        self.model_config = specdecode_config.model_config
        self.cache_config = specdecode_config.cache_config
+        self.num_spec_tokens = specdecode_config.num_speculative_tokens
        self.backend_config = backend_config
        self.device = device

@@ -365,12 +366,17 @@ def build_graph_runner(self):
    def build_cache_engine(self, cache_stream: torch.cuda.Stream):
        """Build cache engine."""
        if self.cache_config is not None:
-            self.cache_engine = CacheEngine(self.cache_config, self.model_config, rank=0, tp_rank=0, world_size=1, cache_stream=cache_stream)
+            self.cache_engine = CacheEngine(self.cache_config,
+                                            self.model_config,
+                                            rank=0,
+                                            tp_rank=0,
+                                            world_size=1,
+                                            cache_stream=cache_stream)

    def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
        """Forward impl."""
        cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
-        output = self.proposer.propose(inputs, cache_engine=self.cache_engine, stream=self.stream)
+        output = self.proposer._forward(inputs, cache_engine=self.cache_engine, stream=self.stream)
        return output

    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
@@ -385,32 +391,122 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
        await asyncio.sleep(0)
        return output

+    async def _async_model_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+        """Model forward.
+
+        Args:
+            inputs (Dict): The input data comes from _make_inputs.
+            swap_in_map (SwapMap): Cache maps to swap in.
+            swap_out_map (SwapMap): Cache maps to swap out.
+        """
+        max_prefill_token_num = self.cache_config.max_prefill_token_num
+        swap_done = False
+
+        async def __forward(inputs):
+            """forward."""
+            nonlocal swap_done, swap_in_map, swap_out_map
+            if swap_done:
+                return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+            else:
+                swap_done = True
+                return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+
+        async def __long_context_single_forward(new_inputs):
+            """One large sequence."""
+            model_metas = new_inputs[0].model_metas
+            for inp in new_inputs:
+                inp.model_metas = model_metas
+                output = await __forward(inp)
+                model_metas = output.get('model_metas')
+            return output
+
+        # make long context inputs
+        is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding
+
+        if is_long_context:
+            seq_len = inputs.seq_length
+            batch_size = seq_len.size(0)
+            assert batch_size == 1, 'Do not support batched long context.'
+            inputs_li = inputs.split(max_prefill_token_num)
+            outputs = await __long_context_single_forward(inputs_li)
+        else:
+            outputs = await __forward(inputs)
+
+        loop_count = self.num_spec_tokens - 1
+        draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
+        draft_tokens_li = [draft_token_ids]
+        if loop_count > 0:
+            inputs = self.proposer.update_inputs_decoding(inputs, draft_token_ids.transpose(0, 1), target_hidden_states,
+                                                          model_metas)
+        for loop_idx in range(loop_count):
+            outputs = await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
+            draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
+            draft_tokens_li.append(draft_token_ids)
+            if loop_idx < loop_count - 1:
+                inputs.update(draft_token_ids.transpose(0, 1))
+                inputs.model_metas = model_metas
+                inputs.target_hidden_states = target_hidden_states
+                if inputs.target_position_ids is not None:
+                    inputs.target_position_ids += 1
+
+        return torch.cat(draft_tokens_li, dim=-1)
+
    async def async_model_forward(self,
                                  model_inputs: ModelInputs,
                                  spec_inputs: SpecDecodeInputs,
                                  swap_in_map: SwapMap = dict(),
                                  swap_out_map: SwapMap = dict()):
        """Draft model forward."""
-        if model_inputs.spec_metadata.draft_token_ids is not None:
-            spec_metadata = model_inputs.spec_metadata
-            output_token_ids, num_rejected_tokens, last_token_ids = self.rejection_sampler(
-                spec_inputs.target_logits, spec_metadata.draft_token_ids, spec_inputs.bonus_token_ids,
-                spec_metadata.num_draft_tokens, spec_metadata.max_spec_len)
-            spec_inputs.num_rejected_tokens = num_rejected_tokens
-            spec_inputs.reject_sample_tokens = output_token_ids
-            spec_inputs.next_token_ids = last_token_ids
-        else:
-            spec_inputs.next_token_ids = spec_inputs.bonus_token_ids
-            output_token_ids = spec_inputs.next_token_ids.unsqueeze(-1)
+        with torch.cuda.stream(self.stream):
+            if model_inputs.spec_metadata.draft_token_ids is not None:
+                spec_metadata = model_inputs.spec_metadata
+                output_token_ids, num_rejected_tokens, last_token_ids = self.rejection_sampler(
+                    spec_inputs.target_logits, spec_metadata.draft_token_ids, spec_inputs.bonus_token_ids,
+                    spec_metadata.num_draft_tokens, spec_metadata.max_spec_len)
+                spec_inputs.num_rejected_tokens = num_rejected_tokens
+                spec_inputs.reject_sample_tokens = output_token_ids
+                spec_inputs.next_token_ids = last_token_ids
+            else:
+                spec_inputs.next_token_ids = spec_inputs.bonus_token_ids
+                output_token_ids = spec_inputs.next_token_ids.unsqueeze(-1)

-        with record_function('draft_prepare_inputs'):
-            draft_model_inputs = self.proposer.prepare_inputs(model_inputs, spec_inputs)
+            with record_function('draft_prepare_inputs'):
+                draft_model_inputs = self.proposer.prepare_inputs(model_inputs, spec_inputs)

-        new_draft_tokens = await self.async_forward(draft_model_inputs,
-                                                    swap_in_map=swap_in_map,
-                                                    swap_out_map=swap_out_map)
-        outputs = dict(output_token_ids=output_token_ids, spec_token_ids=new_draft_tokens)
-        return outputs
+        new_draft_tokens = await self._async_model_forward(draft_model_inputs,
+                                                           swap_in_map=swap_in_map,
+                                                           swap_out_map=swap_out_map)
+        outputs = dict(output_token_ids=output_token_ids, spec_token_ids=new_draft_tokens)
+        return outputs
+
+    def warmup(self, max_batches: int, target_model_config: ModelConfig):
+        """warmup."""
+        target_hidden_size = self.proposer.get_target_hidden_size(target_model_config)
+
+        # warmup prefill
+        inputs = ModelInputs.make_dummy(max_batches,
+                                        is_decoding=False,
+                                        device='cuda',
+                                        vocab_size=self.model_config.vocab_size)
+        inputs.target_hidden_states = torch.randn((1, max_batches, target_hidden_size),
+                                                  dtype=self.model_config.dtype,
+                                                  device='cuda')
+        self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())
+
+        capture_batch_sizes = self.proposer.model.get_capture_batch_sizes()
+        capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)
+
+        for batch_size in capture_batch_sizes:
+            inputs = ModelInputs.make_dummy(
+                batch_size,
+                is_decoding=True,
+                device='cuda',
+                vocab_size=self.model_config.vocab_size,
+            )
+            inputs.target_hidden_states = torch.randn((1, batch_size, self.model_config.hidden_size),
+                                                      dtype=self.model_config.dtype,
+                                                      device='cuda')
+            self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())


class BaseModelAgent:
@@ -525,8 +621,9 @@ def get_free_mem(self):
    def warmup(self):
        """warmup."""
        # TODO: disable for now, do not remove the comments.
-        with self.all_context():
+        with self.all_context(), torch.cuda.stream(self.stream), torch.inference_mode():
            max_batches = self.cache_config.max_batches
+
            num_tokens = max_batches

            # warmup prefill
@@ -546,6 +643,10 @@ def warmup(self):
                                            vocab_size=self.model_config.vocab_size)
            self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())

+            # warmup draft model
+            if self.spec_agent is not None:
+                self.spec_agent.warmup(max_batches, self.model_config)
+
    async def _async_model_forward(
        self,
        inputs: ModelInputs,
@@ -639,8 +740,8 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
            return tmp_out

        # make long context inputs
-        is_long_context = inputs.input_ids.numel(
-        ) > max_prefill_token_num and not inputs.is_decoding and inputs.seq_length[0] == 1
+        is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding
+
        max_seqlen = 0
        if is_long_context:
            seq_len = inputs.seq_length
@@ -1165,7 +1266,7 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
            inputs,
            self.cache_engine,
            stream=self.stream,
-            output_position_ids=self.spec_agent is not None)
+            output_position_ids=False)
        return output

    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
@@ -1194,6 +1295,10 @@ def reset_graph_runner(self):
        if hasattr(self.patched_model, 'reset'):
            self.patched_model.reset()

+        if self.spec_agent is not None:
+            if self.spec_agent.proposer.model is not None and hasattr(self.spec_agent.proposer.model, 'reset'):
+                self.spec_agent.proposer.model.reset()
+
    @torch.inference_mode()
    def update_params(self, request: UpdateParamsRequest):
        """Update params."""

lmdeploy/pytorch/model_inputs.py

Lines changed: 20 additions & 16 deletions

@@ -274,22 +274,26 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
            max_q_seqlen = end - start
            if isinstance(max_q_seqlen, torch.Tensor):
                max_q_seqlen = max_q_seqlen.item()
-            inp = ModelInputs(
-                input_ids=self.input_ids[:, start:end],
-                seq_length=input_ids.new_tensor([end - start]),
-                block_offsets=self.block_offsets,
-                history_lengths=self.history_lengths + start,
-                is_decoding=self.is_decoding,
-                num_ignored_history=self.num_ignored_history,
-                max_q_seqlen=max_q_seqlen,
-                max_kv_seqlen=max_kv_seqlen,
-                sum_kv_seqlen=max_kv_seqlen,
-                local_adapter_ids=self.local_adapter_ids,
-                vision_inputs=vision_inputs,
-                model_metas=self.model_metas,
-                cross_length=cross_length,
-                history_cross_length=history_cross_length,
-            )
+            target_hidden_states = self.target_hidden_states[:, start:
+                                                             end] if self.target_hidden_states is not None else None
+            target_position_ids = self.target_position_ids[:,
+                                                           start:end] if self.target_position_ids is not None else None
+            inp = ModelInputs(input_ids=self.input_ids[:, start:end],
+                              seq_length=input_ids.new_tensor([end - start]),
+                              block_offsets=self.block_offsets,
+                              history_lengths=self.history_lengths + start,
+                              is_decoding=self.is_decoding,
+                              num_ignored_history=self.num_ignored_history,
+                              max_q_seqlen=max_q_seqlen,
+                              max_kv_seqlen=max_kv_seqlen,
+                              sum_kv_seqlen=max_kv_seqlen,
+                              local_adapter_ids=self.local_adapter_ids,
+                              vision_inputs=vision_inputs,
+                              model_metas=self.model_metas,
+                              cross_length=cross_length,
+                              history_cross_length=history_cross_length,
+                              target_hidden_states=target_hidden_states,
+                              target_position_ids=target_position_ids)
            ret.append(inp)
            history_cross_length = cross_length
            max_kv_seqlen += max_q_seqlen
lmdeploy/pytorch/spec_decode/base.py

Lines changed: 18 additions & 2 deletions

@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional

import torch
from mmengine import Registry
+from torch.profiler import record_function

from lmdeploy.utils import get_logger

@@ -79,13 +80,23 @@ def build_model(self, empty_init: bool, target_model: torch.nn.Module = None):
        self.model = patched_model
        self.target_model = target_model

-    def propose(self, model_inputs: ModelInputs, cache_engine: CacheEngine = None, stream: torch.cuda.Stream = None):
+    def get_outputs(self, model_outputs: Dict[str, torch.Tensor], model_inputs: ModelInputs):
+        """Get outputs."""
        raise NotImplementedError()

    def prepare_inputs(self, model_inputs: ModelInputs, spec_inputs: SpecDecodeInputs):
        """Prepare inputs."""
        raise NotImplementedError()

+    @record_function('draft_model_forward')
+    def _forward(self, model_inputs: ModelInputs, cache_engine: CacheEngine = None, stream: torch.cuda.Stream = None):
+        """Forward."""
+        return draft_model_forward(self.model,
+                                   model_inputs,
+                                   model_config=self.specdecode_config.model_config,
+                                   cache_engine=cache_engine,
+                                   stream=stream)
+
    def update_inputs_decoding(self, model_inputs: ModelInputs, input_ids: torch.Tensor,
                               target_hidden_states: torch.Tensor, model_metas: List[Any]):
        """Update to decoding inputs."""
@@ -102,6 +113,7 @@ def update_inputs_decoding(self, model_inputs: ModelInputs, input_ids: torch.Tensor,
        model_inputs.target_hidden_states = target_hidden_states
        return model_inputs

+    @record_function('draft_get_logits')
    def get_logits(self, hidden_states: torch.Tensor):
        """Get logits of model output."""
        draft_model = self.model
@@ -114,6 +126,10 @@ def get_logits(self, hidden_states: torch.Tensor):
        logits = self.target_model.get_logits(hidden_states)
        return logits

+    def get_target_hidden_size(self, model_config: ModelConfig):
+        """Get target hidden size."""
+        return model_config.hidden_size
+

def build_specdecode_proposer(specdecode_config: SpecDecodeConfig, device: str = 'cuda'):
    """Build spec decoding proposer."""

0 commit comments
