
Commit 465b533

support spec with tp and cudagraph
1 parent: 420b5c4


62 files changed, +3918 additions, -373 deletions

benchmark/benchmark_serving.py

Lines changed: 6 additions & 5 deletions
@@ -13,14 +13,15 @@ def get_launching_server_cmd(model_path, backend, server_config):
     elif backend == 'sglang':
         cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path]
     elif backend == 'vllm':
-        cmd = ['vllm', 'serve', '--model', model_path]
+        cmd = ['vllm', 'serve', model_path]
     else:
         raise ValueError(f'unknown backend: {backend}')
     for key, value in server_config.items():
         # Convert snake_case to kebab-case for command line args
         key = key.replace('_', '-')
         cmd.append(f'--{key}')
-        cmd.append(str(value))
+        if str(value):
+            cmd.append(str(value))
     # Special handling for proxy server case
     if server_config.get('proxy_url') and server_config.get('dp'):
         cmd.append('--allow-terminate-by-client')
@@ -66,9 +67,9 @@ def get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]:
         server_ip = server_config.get('server_ip', '0.0.0.0')
         server_port = server_config.get('server_port', 23333)
     elif backend == 'sglang':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 30000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000))
     elif backend == 'vllm':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 8000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000))
     else:
         raise ValueError(f'unknown backend: {backend}')
     return server_ip, server_port
@@ -131,7 +132,7 @@ def benchmark(model_path: str, backend: str, server_config: Dict, data_config: Dict):

     try:

-        print(f"Starting api_server: {' '.join(server_cmd)}")
+        print(f"Starting api_server: {' '.join(server_cmd)}", flush=True)
         proc = subprocess.Popen(server_cmd)
         # Wait for the server to be ready
         wait_server_ready(server_ip, server_port)

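A minimal standalone sketch of the flag-building behaviour changed above: with the new `if str(value)` check, a server_config entry whose value stringifies to an empty string contributes only the flag itself, so boolean-style switches can be passed through. The helper name and config keys below are illustrative, not part of the benchmark script:

def build_flags(server_config: dict) -> list:
    """Convert a snake_case config dict into CLI-style arguments."""
    cmd = []
    for key, value in server_config.items():
        cmd.append(f"--{key.replace('_', '-')}")
        if str(value):  # empty value -> flag-only switch, no trailing argument
            cmd.append(str(value))
    return cmd

print(build_flags({'tp': 2, 'enable_prefix_caching': ''}))
# ['--tp', '2', '--enable-prefix-caching']
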
lmdeploy/api.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,7 @@
 from typing import List, Literal, Optional, Union

 from .archs import autoget_backend_config, get_task
-from .messages import PytorchEngineConfig, TurbomindEngineConfig
+from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig
 from .model import ChatTemplateConfig


@@ -12,6 +12,7 @@ def pipeline(model_path: str,
              chat_template_config: Optional[ChatTemplateConfig] = None,
              log_level: str = 'WARNING',
              max_log_len: int = None,
+             speculative_config: SpeculativeConfig = None,
              **kwargs):
     """
     Args:
@@ -68,6 +69,12 @@ def pipeline(model_path: str,
             if backend_config is not None else None
     model_path = get_model(model_path, download_dir, revision)

+    # spec model
+    if speculative_config is not None and speculative_config.model and not os.path.exists(speculative_config.model):
+        download_dir = backend_config.download_dir \
+            if backend_config is not None else None
+        speculative_config.model = get_model(speculative_config.model, download_dir)
+
     _, pipeline_class = get_task(model_path)
     if not isinstance(backend_config, PytorchEngineConfig):
         # set auto backend mode
@@ -80,6 +87,7 @@ def pipeline(model_path: str,
                           backend_config=backend_config,
                           chat_template_config=chat_template_config,
                           max_log_len=max_log_len,
+                          speculative_config=speculative_config,
                           **kwargs)


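With the api.py change, pipeline() accepts a speculative_config and, when the draft model is not a local path, downloads it the same way as the target model. A usage sketch; the model paths, tp value and draft-token count below are placeholders:

from lmdeploy import pipeline
from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig

# Placeholder paths: any target model plus a matching EAGLE3-style draft model.
spec_cfg = SpeculativeConfig(method='eagle3',
                             model='/models/draft-model',
                             num_speculative_tokens=3)
pipe = pipeline('/models/target-model',
                backend_config=PytorchEngineConfig(tp=2),
                speculative_config=spec_cfg)
print(pipe(['Write a haiku about speculative decoding.']))
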
lmdeploy/cli/cli.py

Lines changed: 14 additions & 3 deletions
@@ -3,7 +3,8 @@
 import os

 from ..version import __version__
-from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args
+from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args,
+                    get_speculative_config)


 class CLI(object):
@@ -44,12 +45,13 @@ def add_parser_chat():
                                 ', "baichuan-inc/baichuan2-7b-chat" and so on')
         # common args
         ArgumentHelper.backend(parser)
-        # # chat template args
+        ArgumentHelper.log_level(parser)
+        # chat template args
         ArgumentHelper.chat_template(parser)
         # model args
         ArgumentHelper.revision(parser)
         ArgumentHelper.download_dir(parser)
-        #
+
         # pytorch engine args
         pt_group = parser.add_argument_group('PyTorch engine arguments')
         ArgumentHelper.adapters(pt_group)
@@ -77,6 +79,9 @@ def add_parser_chat():
         ArgumentHelper.rope_scaling_factor(tb_group)
         ArgumentHelper.communicator(tb_group)

+        # speculative decoding
+        ArgumentHelper.add_spec_group(parser)
+
     @staticmethod
     def add_parser_checkenv():
         """Add parser for check_env command."""
@@ -168,7 +173,13 @@ def get_gpu_topo():
     @staticmethod
     def chat(args):
         from .chat import main
+
         kwargs = convert_args(args)
+        speculative_config = get_speculative_config(args)
+        to_remove = ['speculative_algorithm', 'speculative_draft_model', 'speculative_num_draft_tokens']
+        for key in to_remove:
+            kwargs.pop(key)
+        kwargs['speculative_config'] = speculative_config
         main(**kwargs)

     @staticmethod

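The point of the CLI.chat change is that the three raw speculative_* argparse values never reach chat's main(); they are folded into a single SpeculativeConfig first. A self-contained sketch of that regrouping, with a fabricated kwargs dict standing in for convert_args(args):

from lmdeploy.messages import SpeculativeConfig

kwargs = {
    'model_path': 'internlm/internlm2_5-7b-chat',      # placeholder model
    'speculative_algorithm': 'eagle3',
    'speculative_draft_model': '/models/draft-model',  # placeholder path
    'speculative_num_draft_tokens': 3,
}
kwargs['speculative_config'] = SpeculativeConfig(
    method=kwargs.pop('speculative_algorithm'),
    model=kwargs.pop('speculative_draft_model'),
    num_speculative_tokens=kwargs.pop('speculative_num_draft_tokens'))
print(sorted(kwargs))  # ['model_path', 'speculative_config']
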
lmdeploy/cli/serve.py

Lines changed: 58 additions & 47 deletions
@@ -3,7 +3,8 @@
 from lmdeploy.utils import get_max_batch_size

 from .cli import CLI
-from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters
+from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, convert_args, get_chat_template, get_lora_adapters,
+                    get_speculative_config)


 class SubCliServe:
@@ -144,6 +145,9 @@ def add_parser_api_server():
         vision_group = parser.add_argument_group('Vision model arguments')
         ArgumentHelper.vision_max_batch_size(vision_group)

+        # spec decode
+        ArgumentHelper.add_spec_group(parser)
+
     @staticmethod
     def add_parser_proxy():
         """Add parser for proxy server command."""
@@ -247,61 +251,68 @@ def api_server(args):
                        enable_metrics=args.enable_metrics,
                        hf_overrides=args.hf_overrides)
         chat_template_config = get_chat_template(args.chat_template, args.model_path)
+        speculative_config = get_speculative_config(args)

         from lmdeploy.messages import VisionConfig
         vision_config = VisionConfig(args.vision_max_batch_size)
         if args.dp == 1:
             from lmdeploy.serve.openai.api_server import serve as run_api_server

-            run_api_server(args.model_path,
-                           model_name=args.model_name,
-                           backend=backend,
-                           backend_config=backend_config,
-                           chat_template_config=chat_template_config,
-                           vision_config=vision_config,
-                           server_name=args.server_name,
-                           server_port=args.server_port,
-                           allow_origins=args.allow_origins,
-                           allow_credentials=args.allow_credentials,
-                           allow_methods=args.allow_methods,
-                           allow_headers=args.allow_headers,
-                           allow_terminate_by_client=args.allow_terminate_by_client,
-                           log_level=args.log_level.upper(),
-                           api_keys=args.api_keys,
-                           ssl=args.ssl,
-                           proxy_url=args.proxy_url,
-                           max_log_len=args.max_log_len,
-                           disable_fastapi_docs=args.disable_fastapi_docs,
-                           max_concurrent_requests=args.max_concurrent_requests,
-                           reasoning_parser=args.reasoning_parser,
-                           tool_call_parser=args.tool_call_parser)
+            run_api_server(
+                args.model_path,
+                model_name=args.model_name,
+                backend=backend,
+                backend_config=backend_config,
+                chat_template_config=chat_template_config,
+                vision_config=vision_config,
+                server_name=args.server_name,
+                server_port=args.server_port,
+                allow_origins=args.allow_origins,
+                allow_credentials=args.allow_credentials,
+                allow_methods=args.allow_methods,
+                allow_headers=args.allow_headers,
+                allow_terminate_by_client=args.allow_terminate_by_client,
+                log_level=args.log_level.upper(),
+                api_keys=args.api_keys,
+                ssl=args.ssl,
+                proxy_url=args.proxy_url,
+                max_log_len=args.max_log_len,
+                disable_fastapi_docs=args.disable_fastapi_docs,
+                max_concurrent_requests=args.max_concurrent_requests,
+                reasoning_parser=args.reasoning_parser,
+                tool_call_parser=args.tool_call_parser,
+                speculative_config=speculative_config,
+            )
         else:
             from lmdeploy.serve.openai.launch_server import launch_server

-            launch_server(args.nnodes,
-                          args.node_rank,
-                          args.model_path,
-                          model_name=args.model_name,
-                          backend=backend,
-                          backend_config=backend_config,
-                          chat_template_config=chat_template_config,
-                          vision_config=vision_config,
-                          server_name=args.server_name,
-                          server_port=args.server_port,
-                          allow_origins=args.allow_origins,
-                          allow_credentials=args.allow_credentials,
-                          allow_methods=args.allow_methods,
-                          allow_headers=args.allow_headers,
-                          allow_terminate_by_client=args.allow_terminate_by_client,
-                          log_level=args.log_level.upper(),
-                          api_keys=args.api_keys,
-                          ssl=args.ssl,
-                          proxy_url=args.proxy_url,
-                          max_log_len=args.max_log_len,
-                          disable_fastapi_docs=args.disable_fastapi_docs,
-                          max_concurrent_requests=args.max_concurrent_requests,
-                          reasoning_parser=args.reasoning_parser,
-                          tool_call_parser=args.tool_call_parser)
+            launch_server(
+                args.nnodes,
+                args.node_rank,
+                args.model_path,
+                model_name=args.model_name,
+                backend=backend,
+                backend_config=backend_config,
+                chat_template_config=chat_template_config,
+                vision_config=vision_config,
+                server_name=args.server_name,
+                server_port=args.server_port,
+                allow_origins=args.allow_origins,
+                allow_credentials=args.allow_credentials,
+                allow_methods=args.allow_methods,
+                allow_headers=args.allow_headers,
+                allow_terminate_by_client=args.allow_terminate_by_client,
+                log_level=args.log_level.upper(),
+                api_keys=args.api_keys,
+                ssl=args.ssl,
+                proxy_url=args.proxy_url,
+                max_log_len=args.max_log_len,
+                disable_fastapi_docs=args.disable_fastapi_docs,
+                max_concurrent_requests=args.max_concurrent_requests,
+                reasoning_parser=args.reasoning_parser,
+                tool_call_parser=args.tool_call_parser,
+                speculative_config=speculative_config,
+            )

     @staticmethod
     def proxy(args):

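serve() (aliased run_api_server above) and launch_server() now both forward a speculative_config to the engine. A minimal programmatic sketch of the single-DP path, equivalent to passing --speculative-algorithm / --speculative-draft-model / --speculative-num-draft-tokens to `lmdeploy serve api_server`; paths, tp and port below are placeholders:

from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig
from lmdeploy.serve.openai.api_server import serve

serve('/models/target-model',                       # placeholder target model
      backend='pytorch',
      backend_config=PytorchEngineConfig(tp=2),
      speculative_config=SpeculativeConfig(method='eagle3',
                                           model='/models/draft-model',  # placeholder draft model
                                           num_speculative_tokens=3),
      server_name='0.0.0.0',
      server_port=23333)
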
lmdeploy/cli/utils.py

Lines changed: 33 additions & 0 deletions
@@ -99,6 +99,19 @@ def get_chat_template(chat_template: str, model_path: str = None):
     return None


+def get_speculative_config(args):
+    """Get speculative config from args."""
+    from lmdeploy.messages import SpeculativeConfig
+    speculative_config = None
+    if args.speculative_algorithm is not None:
+        speculative_config = SpeculativeConfig(
+            method=args.speculative_algorithm,
+            model=args.speculative_draft_model,
+            num_speculative_tokens=args.speculative_num_draft_tokens,
+        )
+    return speculative_config
+
+
 class ArgumentHelper:
     """Helper class to add unified argument."""

@@ -654,6 +667,26 @@ def dllm_confidence_threshold(parser):
                             default=0.85,
                             help='The confidence threshold for dllm.')

+    def add_spec_group(parser):
+        spec_group = parser.add_argument_group('Speculative decoding arguments')
+        spec_group.add_argument('--speculative-algorithm',
+                                type=str,
+                                default=None,
+                                choices=['eagle', 'eagle3', 'deepseek_mtp'],
+                                help='The speculative algorithm to use. `None` means speculative decoding is disabled')
+
+        spec_group.add_argument('--speculative-draft-model',
+                                type=str,
+                                default=None,
+                                help='The path to speculative draft model')
+
+        spec_group.add_argument('--speculative-num-draft-tokens',
+                                type=int,
+                                default=1,
+                                help='The number of speculative tokens to generate per step')
+
+        return spec_group
+

 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
 class FlexibleArgumentParser(argparse.ArgumentParser):

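get_speculative_config only builds a config when --speculative-algorithm was given, so commands that do not opt in keep speculative decoding disabled. A quick check with hand-built namespaces (values are illustrative):

import argparse

from lmdeploy.cli.utils import get_speculative_config

# No algorithm -> None, speculative decoding stays disabled.
plain = argparse.Namespace(speculative_algorithm=None,
                           speculative_draft_model=None,
                           speculative_num_draft_tokens=1)
print(get_speculative_config(plain))  # None

# With an algorithm, the three CLI values are folded into one SpeculativeConfig.
spec = argparse.Namespace(speculative_algorithm='eagle3',
                          speculative_draft_model='/models/draft-model',  # placeholder
                          speculative_num_draft_tokens=3)
print(get_speculative_config(spec))
# SpeculativeConfig(method='eagle3', model='/models/draft-model', num_speculative_tokens=3)
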
lmdeploy/messages.py

Lines changed: 15 additions & 0 deletions
@@ -522,6 +522,7 @@ class RequestMetrics:
     """
     token_timestamp: float = 0.0
     engine_events: List[EngineEvent] = field(default_factory=list)
+    spec_info: Optional[Dict[str, Any]] = None


 @dataclass
@@ -562,3 +563,17 @@ class VisionConfig:
     """
     max_batch_size: int = 1
     thread_safe: bool = False
+
+
+@dataclass
+class SpeculativeConfig:
+    """Speculative decoding config.
+
+    Args:
+        method (str): the speculative decoding method.
+        model (str): the path of speculative model.
+        num_speculative_tokens (int): number of generated token of draft model per step
+    """
+    method: str
+    model: str = ''
+    num_speculative_tokens: int = 1
