Commit c7282c4 (1 parent: 7bf9265)

support eagle3, mtp with cudagraph and tp and long text input


41 files changed: +3069 −330 lines

benchmark/benchmark_serving.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@ def get_launching_server_cmd(model_path, backend, server_config):
     elif backend == 'sglang':
         cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path]
     elif backend == 'vllm':
-        cmd = ['vllm', 'serve', '--model', model_path]
+        cmd = ['vllm', 'serve', model_path]
     else:
         raise ValueError(f'unknown backend: {backend}')
     for key, value in server_config.items():
@@ -131,7 +131,7 @@ def benchmark(model_path: str, backend: str, server_config: Dict, data_config: D
 
     try:
 
-        print(f"Starting api_server: {' '.join(server_cmd)}")
+        print(f"Starting api_server: {' '.join(server_cmd)}", flush=True)
         proc = subprocess.Popen(server_cmd)
         # Wait for the server to be ready
         wait_server_ready(server_ip, server_port)
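Two small fixes here: `vllm serve` takes the model path as a positional argument rather than via `--model`, and `flush=True` pushes the log line through Python's stdout buffer before the process blocks in `wait_server_ready`. When stdout is piped (as it typically is under a benchmark or CI runner), `print` output is block-buffered and may stay invisible long after the server has started. A minimal sketch of the pattern; the `sleep` command is a stand-in for a real server process:

import subprocess
import time

server_cmd = ['sleep', '30']  # stand-in for the real server command

# Without flush=True this line may sit in the stdout buffer until it
# fills, long after the server process has been launched.
print(f"Starting api_server: {' '.join(server_cmd)}", flush=True)
proc = subprocess.Popen(server_cmd)
time.sleep(1.0)  # stand-in for wait_server_ready(server_ip, server_port)
proc.terminate()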

lmdeploy/api.py

Lines changed: 9 additions & 1 deletion

@@ -3,7 +3,7 @@
 from typing import List, Literal, Optional, Union
 
 from .archs import autoget_backend_config, get_task
-from .messages import PytorchEngineConfig, TurbomindEngineConfig
+from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig
 from .model import ChatTemplateConfig
 
 
@@ -12,6 +12,7 @@ def pipeline(model_path: str,
              chat_template_config: Optional[ChatTemplateConfig] = None,
              log_level: str = 'WARNING',
              max_log_len: int = None,
+             speculative_config: SpeculativeConfig = None,
              **kwargs):
     """
     Args:
@@ -68,6 +69,12 @@ def pipeline(model_path: str,
             if backend_config is not None else None
         model_path = get_model(model_path, download_dir, revision)
 
+    # spec model
+    if speculative_config is not None and speculative_config.model and not os.path.exists(speculative_config.model):
+        download_dir = backend_config.download_dir \
+            if backend_config is not None else None
+        speculative_config.model = get_model(speculative_config.model, download_dir)
+
     task, pipeline_class = get_task(model_path)
     if task == 'vlm':
         if backend_config and backend_config.enable_prefix_caching:
@@ -85,6 +92,7 @@
                           backend_config=backend_config,
                           chat_template_config=chat_template_config,
                           max_log_len=max_log_len,
+                          speculative_config=speculative_config,
                           **kwargs)
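With this change, speculative decoding can be switched on from the Python API by passing a SpeculativeConfig to pipeline(); if speculative_config.model is a repo id rather than a local path, it is downloaded the same way the target model is. A minimal usage sketch; both model paths are hypothetical placeholders:

from lmdeploy import pipeline
from lmdeploy.messages import SpeculativeConfig

# Hypothetical paths: substitute a real target model and a matching
# EAGLE3 draft model.
spec_cfg = SpeculativeConfig(method='eagle3',
                             model='path/to/eagle3-draft-model',
                             num_speculative_tokens=3)

pipe = pipeline('path/to/target-model', speculative_config=spec_cfg)
print(pipe(['Hi, please introduce yourself']))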

lmdeploy/cli/cli.py

Lines changed: 13 additions & 3 deletions

@@ -3,7 +3,7 @@
 import os
 
 from ..version import __version__
-from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args
+from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args, get_speculative_config
 
 
 class CLI(object):
@@ -44,12 +44,13 @@ def add_parser_chat():
                             ', "baichuan-inc/baichuan2-7b-chat" and so on')
         # common args
         ArgumentHelper.backend(parser)
-        # # chat template args
+        ArgumentHelper.log_level(parser)
+        # chat template args
         ArgumentHelper.chat_template(parser)
         # model args
         ArgumentHelper.revision(parser)
         ArgumentHelper.download_dir(parser)
-        #
+
         # pytorch engine args
         pt_group = parser.add_argument_group('PyTorch engine arguments')
         ArgumentHelper.adapters(pt_group)
@@ -76,6 +77,9 @@ def add_parser_chat():
         ArgumentHelper.rope_scaling_factor(tb_group)
         ArgumentHelper.communicator(tb_group)
 
+        # speculative decoding
+        ArgumentHelper.add_spec_group(parser)
+
     @staticmethod
     def add_parser_checkenv():
         """Add parser for check_env command."""
@@ -167,7 +171,13 @@ def get_gpu_topo():
     @staticmethod
     def chat(args):
         from .chat import main
+
         kwargs = convert_args(args)
+        speculative_config = get_speculative_config(args)
+        to_remove = ['speculative_algorithm','speculative_draft_model','speculative_num_draft_tokens']
+        for key in to_remove:
+            kwargs.pop(key)
+        kwargs['speculative_config'] = speculative_config
         main(**kwargs)
 
 

lmdeploy/cli/serve.py

Lines changed: 58 additions & 47 deletions

@@ -3,7 +3,8 @@
 from lmdeploy.utils import get_max_batch_size
 
 from .cli import CLI
-from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters
+from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters,
+                    get_speculative_config)
 
 
 class SubCliServe:
@@ -140,6 +141,9 @@ def add_parser_api_server():
         vision_group = parser.add_argument_group('Vision model arguments')
         ArgumentHelper.vision_max_batch_size(vision_group)
 
+        # spec decode
+        ArgumentHelper.add_spec_group(parser)
+
     @staticmethod
     def add_parser_proxy():
         """Add parser for proxy server command."""
@@ -239,61 +243,68 @@ def api_server(args):
                                        enable_metrics=args.enable_metrics,
                                        hf_overrides=args.hf_overrides)
         chat_template_config = get_chat_template(args.chat_template)
+        speculative_config = get_speculative_config(args)
 
         from lmdeploy.messages import VisionConfig
         vision_config = VisionConfig(args.vision_max_batch_size)
         if args.dp == 1:
             from lmdeploy.serve.openai.api_server import serve as run_api_server
 
-            run_api_server(args.model_path,
-                           model_name=args.model_name,
-                           backend=backend,
-                           backend_config=backend_config,
-                           chat_template_config=chat_template_config,
-                           vision_config=vision_config,
-                           server_name=args.server_name,
-                           server_port=args.server_port,
-                           allow_origins=args.allow_origins,
-                           allow_credentials=args.allow_credentials,
-                           allow_methods=args.allow_methods,
-                           allow_headers=args.allow_headers,
-                           allow_terminate_by_client=args.allow_terminate_by_client,
-                           log_level=args.log_level.upper(),
-                           api_keys=args.api_keys,
-                           ssl=args.ssl,
-                           proxy_url=args.proxy_url,
-                           max_log_len=args.max_log_len,
-                           disable_fastapi_docs=args.disable_fastapi_docs,
-                           max_concurrent_requests=args.max_concurrent_requests,
-                           reasoning_parser=args.reasoning_parser,
-                           tool_call_parser=args.tool_call_parser)
+            run_api_server(
+                args.model_path,
+                model_name=args.model_name,
+                backend=backend,
+                backend_config=backend_config,
+                chat_template_config=chat_template_config,
+                vision_config=vision_config,
+                server_name=args.server_name,
+                server_port=args.server_port,
+                allow_origins=args.allow_origins,
+                allow_credentials=args.allow_credentials,
+                allow_methods=args.allow_methods,
+                allow_headers=args.allow_headers,
+                allow_terminate_by_client=args.allow_terminate_by_client,
+                log_level=args.log_level.upper(),
+                api_keys=args.api_keys,
+                ssl=args.ssl,
+                proxy_url=args.proxy_url,
+                max_log_len=args.max_log_len,
+                disable_fastapi_docs=args.disable_fastapi_docs,
+                max_concurrent_requests=args.max_concurrent_requests,
+                reasoning_parser=args.reasoning_parser,
+                tool_call_parser=args.tool_call_parser,
+                speculative_config=speculative_config,
+            )
         else:
             from lmdeploy.serve.openai.launch_server import launch_server
 
-            launch_server(args.nnodes,
-                          args.node_rank,
-                          args.model_path,
-                          model_name=args.model_name,
-                          backend=backend,
-                          backend_config=backend_config,
-                          chat_template_config=chat_template_config,
-                          vision_config=vision_config,
-                          server_name=args.server_name,
-                          server_port=args.server_port,
-                          allow_origins=args.allow_origins,
-                          allow_credentials=args.allow_credentials,
-                          allow_methods=args.allow_methods,
-                          allow_headers=args.allow_headers,
-                          allow_terminate_by_client=args.allow_terminate_by_client,
-                          log_level=args.log_level.upper(),
-                          api_keys=args.api_keys,
-                          ssl=args.ssl,
-                          proxy_url=args.proxy_url,
-                          max_log_len=args.max_log_len,
-                          disable_fastapi_docs=args.disable_fastapi_docs,
-                          max_concurrent_requests=args.max_concurrent_requests,
-                          reasoning_parser=args.reasoning_parser,
-                          tool_call_parser=args.tool_call_parser)
+            launch_server(
+                args.nnodes,
+                args.node_rank,
+                args.model_path,
+                model_name=args.model_name,
+                backend=backend,
+                backend_config=backend_config,
+                chat_template_config=chat_template_config,
+                vision_config=vision_config,
+                server_name=args.server_name,
+                server_port=args.server_port,
+                allow_origins=args.allow_origins,
+                allow_credentials=args.allow_credentials,
+                allow_methods=args.allow_methods,
+                allow_headers=args.allow_headers,
+                allow_terminate_by_client=args.allow_terminate_by_client,
+                log_level=args.log_level.upper(),
+                api_keys=args.api_keys,
+                ssl=args.ssl,
+                proxy_url=args.proxy_url,
+                max_log_len=args.max_log_len,
+                disable_fastapi_docs=args.disable_fastapi_docs,
+                max_concurrent_requests=args.max_concurrent_requests,
+                reasoning_parser=args.reasoning_parser,
+                tool_call_parser=args.tool_call_parser,
+                speculative_config=speculative_config,
+            )
 
     @staticmethod
     def proxy(args):
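The move to the trailing-comma call style keeps future argument additions to a one-line diff; the functional change is the speculative_config plumbed into both entry points. From the command line this surfaces as, e.g., `lmdeploy serve api_server <model> --speculative-algorithm eagle3 --speculative-draft-model <draft> --speculative-num-draft-tokens 3`. The same entry point can also be driven programmatically; a minimal sketch with hypothetical model paths:

from lmdeploy.messages import SpeculativeConfig
from lmdeploy.serve.openai.api_server import serve as run_api_server

# Hypothetical checkpoints: substitute a real target model and a
# matching EAGLE3 draft model.
spec_cfg = SpeculativeConfig(method='eagle3',
                             model='path/to/eagle3-draft-model',
                             num_speculative_tokens=3)

# Blocks and serves an OpenAI-compatible API on port 23333.
run_api_server('path/to/target-model',
               server_name='0.0.0.0',
               server_port=23333,
               speculative_config=spec_cfg)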

lmdeploy/cli/utils.py

Lines changed: 33 additions & 0 deletions

@@ -86,6 +86,19 @@ def get_chat_template(chat_template: str):
     return None
 
 
+def get_speculative_config(args):
+    """Get speculative config from args."""
+    from lmdeploy.messages import SpeculativeConfig
+    speculative_config = None
+    if args.speculative_algorithm is not None:
+        speculative_config = SpeculativeConfig(
+            method=args.speculative_algorithm,
+            model=args.speculative_draft_model,
+            num_speculative_tokens=args.speculative_num_draft_tokens,
+        )
+    return speculative_config
+
+
 class ArgumentHelper:
     """Helper class to add unified argument."""
 
@@ -610,6 +623,26 @@ def logprobs_mode(parser):
                             choices=[None, 'raw_logits', 'raw_logprobs'],
                             help='The mode of logprobs.')
 
+    def add_spec_group(parser):
+        spec_group = parser.add_argument_group('Speculative decoding arguments')
+        spec_group.add_argument('--speculative-algorithm',
+                                type=str,
+                                default=None,
+                                choices=['eagle', 'eagle3', 'deepseek_mtp'],
+                                help='The speculative algorithm to use. `None` means speculative decoding is disabled')
+
+        spec_group.add_argument('--speculative-draft-model',
+                                type=str,
+                                default=None,
+                                help='The path to speculative draft model')
+
+        spec_group.add_argument('--speculative-num-draft-tokens',
+                                type=int,
+                                default=1,
+                                help='The number of speculative tokens to generate per step')
+
+        return spec_group
+
 
 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
 class FlexibleArgumentParser(argparse.ArgumentParser):
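get_speculative_config returns None unless --speculative-algorithm was given, so downstream code can treat an absent flag as "speculative decoding disabled". A quick sketch of that mapping, using argparse.Namespace to stand in for parsed CLI args (the draft-model path is a hypothetical placeholder):

from argparse import Namespace

from lmdeploy.cli.utils import get_speculative_config

# No --speculative-algorithm -> the helper returns None.
args = Namespace(speculative_algorithm=None,
                 speculative_draft_model=None,
                 speculative_num_draft_tokens=1)
assert get_speculative_config(args) is None

# Flags present -> a populated SpeculativeConfig.
args = Namespace(speculative_algorithm='eagle3',
                 speculative_draft_model='path/to/eagle3-draft-model',
                 speculative_num_draft_tokens=3)
cfg = get_speculative_config(args)
print(cfg.method, cfg.model, cfg.num_speculative_tokens)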

lmdeploy/messages.py

Lines changed: 15 additions & 0 deletions

@@ -509,6 +509,7 @@ class RequestMetrics:
     """
     token_timestamp: float = 0.0
     engine_events: List[EngineEvent] = field(default_factory=list)
+    spec_info: Optional[Dict[str, Any]] = None
 
 
 @dataclass
@@ -549,3 +550,17 @@ class VisionConfig:
     """
     max_batch_size: int = 1
     thread_safe: bool = False
+
+
+@dataclass
+class SpeculativeConfig:
+    """Speculative decoding config.
+
+    Args:
+        method (str): the speculative decoding method.
+        model (str): the path of speculative model.
+        num_speculative_tokens (int): number of generated token of draft model per step
+    """
+    method: str
+    model: str = ''
+    num_speculative_tokens: int = 1