
Commit 343eede

[Feature]: return routed experts to reuse (#4090)
* return routed experts
* only on tp rank0
* no change to model config
* clone experts from cudagraph buffer
* use ray to transfer experts
* only put data to ray obj store when use api server
* enable ray transfer for apiserver
* resolve comments
* resolve comments
1 parent bf52c04 commit 343eede
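For orientation, a minimal usage sketch (not part of this commit): the feature is switched on engine-wide through PytorchEngineConfig.enable_return_routed_experts and requested per generation through GenerationConfig.return_routed_experts; the expert ids then come back on Response.routed_experts. The model path below is a placeholder; the field names follow the diffs in this commit.

# Minimal sketch, assuming an MoE model served by the PyTorch engine.
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(enable_return_routed_experts=True)
pipe = pipeline('Qwen/Qwen1.5-MoE-A2.7B-Chat', backend_config=backend_config)  # placeholder model

gen_config = GenerationConfig(max_new_tokens=64, return_routed_experts=True)
resp = pipe(['Hello, who are you?'], gen_config=gen_config)[0]

print(resp.text)
if resp.routed_experts is not None:
    # In-process pipeline: a tensor of routed expert ids.
    # API server with enable_transfer_obj_ref: a serialized Ray object reference instead.
    print(resp.routed_experts.shape)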

File tree: 20 files changed (+343, -72 lines)


benchmark/profile_pipeline_api.py

Lines changed: 4 additions & 1 deletion
@@ -134,7 +134,7 @@ class Engine:
     def __init__(self, model_path: str, engine_config, csv: str):
         self.pipe = pipeline(model_path, backend_config=engine_config, log_level='ERROR')
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
+        self.return_routed_experts = getattr(self.pipe.backend_config, 'enable_return_routed_experts', False)
         self.csv = csv

     def process_request(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output):
@@ -146,6 +146,7 @@ def process_request(self, requests, profiler: Profiler, temperature, top_p, top_
                               top_k=top_k,
                               ignore_eos=True,
                               do_sample=False,
+                              return_routed_experts=self.return_routed_experts,
                               max_new_tokens=output_len) for _, _, output_len in requests
         ]

@@ -254,6 +255,7 @@ def parse_args():
     # pytorch engine args
     pt_group = parser.add_argument_group('PyTorch engine arguments')
     ArgumentHelper.eager_mode(pt_group)
+    ArgumentHelper.enable_return_routed_experts(pt_group)

     tp_act = ArgumentHelper.tp(pt_group)
     cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
@@ -302,6 +304,7 @@ def main():
         thread_safe=False,
         eager_mode=args.eager_mode,
         enable_prefix_caching=args.enable_prefix_caching,
+        enable_return_routed_experts=args.enable_return_routed_experts,
     )

     engine = Engine(args.model_path, engine_config, csv=args.csv)
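The getattr fallback in the first hunk above is there because only PytorchEngineConfig gains the new field in this commit; a TurbomindEngineConfig or an older config object has no such attribute. A standalone illustration of the pattern, using stand-in config classes rather than lmdeploy's:

# Stand-in classes; only the attribute-lookup pattern mirrors the benchmark code.
from dataclasses import dataclass


@dataclass
class PtConfig:  # stands in for PytorchEngineConfig
    enable_return_routed_experts: bool = True


@dataclass
class TmConfig:  # stands in for TurbomindEngineConfig, which lacks the field
    tp: int = 1


for cfg in (PtConfig(), TmConfig()):
    flag = getattr(cfg, 'enable_return_routed_experts', False)
    print(type(cfg).__name__, flag)  # PtConfig True, TmConfig False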

lmdeploy/cli/serve.py

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,7 @@ def add_parser_api_server():
         ArgumentHelper.dllm_unmasking_strategy(pt_group)
         ArgumentHelper.dllm_denoising_steps(pt_group)
         ArgumentHelper.dllm_confidence_threshold(pt_group)
+        ArgumentHelper.enable_return_routed_experts(pt_group)

         # common engine args
         dtype_act = ArgumentHelper.dtype(pt_group)
@@ -228,6 +229,7 @@ def api_server(args):
                 dllm_unmasking_strategy=args.dllm_unmasking_strategy,
                 dllm_denoising_steps=args.dllm_denoising_steps,
                 dllm_confidence_threshold=args.dllm_confidence_threshold,
+                enable_return_routed_experts=args.enable_return_routed_experts,
             )
         else:
             from lmdeploy.messages import TurbomindEngineConfig

lmdeploy/cli/utils.py

Lines changed: 9 additions & 0 deletions
@@ -667,6 +667,15 @@ def dllm_confidence_threshold(parser):
                                    default=0.85,
                                    help='The confidence threshold for dllm.')

+    @staticmethod
+    def enable_return_routed_experts(parser):
+        """Add argument return routed experts to parser."""
+
+        return parser.add_argument('--enable-return-routed-experts',
+                                   action='store_true',
+                                   default=False,
+                                   help='Whether to output routed expert ids for replay')
+

 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
 class FlexibleArgumentParser(argparse.ArgumentParser):
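The new helper only wraps parser.add_argument, so the resulting CLI surface is an ordinary store_true flag. A standalone argparse sketch of the equivalent behaviour, independent of lmdeploy's ArgumentHelper:

# Equivalent flag on a bare argparse parser; group name and help text copied from the diff above.
import argparse

parser = argparse.ArgumentParser('api_server')
pt_group = parser.add_argument_group('PyTorch engine arguments')
pt_group.add_argument('--enable-return-routed-experts',
                      action='store_true',
                      default=False,
                      help='Whether to output routed expert ids for replay')

args = parser.parse_args(['--enable-return-routed-experts'])
assert args.enable_return_routed_experts is True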

lmdeploy/messages.py

Lines changed: 13 additions & 2 deletions
@@ -117,6 +117,9 @@ class GenerationConfig:
     preserve_cache: bool = False
     migration_request: Optional[MigrationRequest] = None

+    # router replay
+    return_routed_experts: bool = False
+
     def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
         """Convert stop_words/bad_sords to ids and append the ids to
         stop_token_ids/bad_token_ids."""
@@ -376,6 +379,9 @@ class PytorchEngineConfig:
     hf_overrides: Optional[Dict[str, Any]] = None
     disable_vision_encoder: bool = False
     logprobs_mode: str = None
+    # router replay
+    enable_return_routed_experts: bool = False
+    enable_transfer_obj_ref: bool = False

     # dllm
     dllm_block_length: int = None
@@ -457,14 +463,18 @@ class Response:
     logits: torch.Tensor = None
     last_hidden_state: torch.Tensor = None
     index: int = 0
+    routed_experts: Any = None

     def __repr__(self):
         logits = 'logits=None' if self.logits is None else f'logits.shape={self.logits.shape}\nlogits={self.logits}'
         hidden_state = (
             'last_hidden_state=None' if self.last_hidden_state is None else
             f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
-        s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
-             f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}')
+        routed_experts = 'routed_experts=None' if self.routed_experts is None else \
+            f'routed_experts.shape={self.routed_experts.shape}'
+
+        s = (f'text={self.text!r}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
+             f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}\n{routed_experts}')
         return s


@@ -544,6 +554,7 @@ class EngineOutput:
     last_hidden_state: torch.Tensor = None
     cache_block_ids: Optional[List[int]] = None
     req_metrics: Optional[RequestMetrics] = None
+    routed_experts: torch.Tensor = None


 @dataclass
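Besides appending routed_experts to the repr, the Response hunk switches text= to {self.text!r}. A quick standalone illustration of what that changes for multi-line generations:

# f'{s}' interpolates the raw string; f'{s!r}' shows quotes and escape sequences,
# keeping a multi-line generation on a single line of the repr.
s = 'line one\nline two'
print(f'text={s}')    # prints two lines: 'text=line one' then 'line two'
print(f'text={s!r}')  # prints one line: text='line one\nline two'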

lmdeploy/pytorch/backends/cuda/graph_runner.py

Lines changed: 6 additions & 16 deletions
@@ -91,16 +91,6 @@ def __init__(
         self.pool = pool
         self._graph: torch.cuda.CUDAGraph = None

-    def make_output_buffers(self, output):
-        """Make output buffers."""
-        output_buffers = dict(logits=output)
-        return output_buffers
-
-    def slice_output(self, output_buffers: Dict[str, Any], inputs: Dict[str, Any]):
-        """Slice output."""
-        num_tokens = inputs['input_ids'].size(-1)
-        return output_buffers['logits'][:, :num_tokens]
-
     @record_function('capture_cudagraph')
     def capture(self, **kwargs):
         """Capture graph."""
@@ -113,17 +103,17 @@ def capture(self, **kwargs):

         # warmup
         warmup_output = self.model(**padded_kwargs)
-        warmup_buffers = self.make_output_buffers(warmup_output)
+        warmup_buffers = self.model.make_output_buffers(warmup_output)

         self._graph = torch.cuda.CUDAGraph()
         # unsafe kernel call in other thread might invalid the capture
         # so we set thread_safe capture mode here.
         with torch.cuda.graph(self._graph, pool=self.pool, stream=current_stream, capture_error_mode='thread_local'):
             output = self.model(**padded_kwargs)

-        output_buffers = self.make_output_buffers(output)
+        output_buffers = self.model.make_output_buffers(output)
         self.meta.output_buffers = output_buffers
-        output = self.slice_output(warmup_buffers, kwargs)
+        output = self.model.get_outputs_cudagraph(warmup_buffers, **kwargs)
         return output

     @record_function('forward_cudagraph')
@@ -134,9 +124,8 @@ def forward(self, **kwargs):
         context = self.ctx_mgr.current_context()
         self.model.update_context_cudagraph(self.meta, context)
         self._graph.replay()
-
         output_buffers = self.meta.output_buffers
-        output = self.slice_output(output_buffers, kwargs)
+        output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)
         return output

     def __del__(self):
@@ -220,7 +209,8 @@ def __call__(self, **kwargs):

         if not enable_graph:
             with record_function('forward_eager'):
-                return self.model(**kwargs)
+                output = self.model(**kwargs)
+                return self.model.make_output_buffers(output)

         graph_key = self.get_graph_key(**kwargs)
         max_batches = graph_key[0]
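The refactor above removes output-buffer handling from the graph runner and delegates it to the model wrapper, so a model can expose additional CUDA-graph outputs (such as routed expert ids) without the runner knowing their names. The concrete hook implementations live in the model code, which is outside this excerpt; below is a hypothetical sketch of the contract the runner now relies on, where every buffer name other than 'logits' is an assumption:

# Hypothetical model-side hooks; only the call signatures are taken from the diff above.
from typing import Any, Dict


class GraphOutputHooks:

    def make_output_buffers(self, output) -> Dict[str, Any]:
        """Wrap forward outputs into a dict of named buffers kept alive across replays."""
        if isinstance(output, dict):
            return output  # e.g. dict(logits=..., routed_experts=...)
        return dict(logits=output)

    def get_outputs_cudagraph(self, output_buffers: Dict[str, Any], **kwargs):
        """Slice each captured buffer back to the real token count of this step."""
        num_tokens = kwargs['input_ids'].size(-1)
        outputs = dict(logits=output_buffers['logits'][:, :num_tokens])
        if 'routed_experts' in output_buffers:
            # cloned so callers do not alias the graph-owned buffer
            # (the commit message mentions cloning experts from the cudagraph buffer)
            outputs['routed_experts'] = output_buffers['routed_experts'][:num_tokens].clone()
        return outputs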

lmdeploy/pytorch/config.py

Lines changed: 12 additions & 8 deletions
@@ -347,6 +347,7 @@ class MiscConfig:
     disable_vision_encoder: bool = False
     logprobs_mode: str = None
     dllm_config: DLLMConfig = None
+    enable_return_routed_experts: bool = False

     @classmethod
     def from_engine_config(cls, engine_config: PytorchEngineConfig):
@@ -356,12 +357,15 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig):
                                  unmasking_strategy=dllm_unmasking_strategy,
                                  denoising_steps=engine_config.dllm_denoising_steps,
                                  confidence_threshold=engine_config.dllm_confidence_threshold)
-        misc_config = cls(custom_module_map=engine_config.custom_module_map,
-                          empty_init=engine_config.empty_init,
-                          prefill_interval=engine_config.prefill_interval,
-                          model_format=engine_config.model_format,
-                          hf_overrides=engine_config.hf_overrides,
-                          disable_vision_encoder=engine_config.disable_vision_encoder,
-                          logprobs_mode=engine_config.logprobs_mode,
-                          dllm_config=dllm_config)
+        misc_config = cls(
+            custom_module_map=engine_config.custom_module_map,
+            empty_init=engine_config.empty_init,
+            prefill_interval=engine_config.prefill_interval,
+            model_format=engine_config.model_format,
+            hf_overrides=engine_config.hf_overrides,
+            disable_vision_encoder=engine_config.disable_vision_encoder,
+            logprobs_mode=engine_config.logprobs_mode,
+            dllm_config=dllm_config,
+            enable_return_routed_experts=engine_config.enable_return_routed_experts,
+        )
         return misc_config

lmdeploy/pytorch/engine/engine.py

Lines changed: 16 additions & 1 deletion
@@ -57,6 +57,9 @@ class InferOutput:
     # for logging
     req_metrics: RequestMetrics = None

+    # expert ids
+    routed_experts: torch.Tensor = None
+

 def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
     """Tensorlize block_offsets."""
@@ -876,13 +879,18 @@ def _make_infer_outputs(
                 cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])

             req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
+            routed_experts = msg.routed_experts if msg.return_routed_experts and finish else None
+            if routed_experts is not None and self.engine_config.enable_transfer_obj_ref:
+                # only serialize for api server
+                routed_experts = self.executor.serialize(routed_experts)
             out = InferOutput(session_id=session_id,
                               resp=msg.resp,
                               finish=finish,
                               token_ids=token_ids,
                               cache_block_ids=cache_block_ids,
                               req_metrics=req_metrics,
-                              logprobs=cur_logprobs)
+                              logprobs=cur_logprobs,
+                              routed_experts=routed_experts)
             outputs[session_id] = out

             if msg.return_logits:
@@ -896,6 +904,10 @@ def __need_logits(seqs: SeqList):
             """Need logits."""
             return any(seq.return_logits for seq in seqs)

+        def __need_routed_experts(seqs: SeqList):
+            """Need routed experts."""
+            return any(seq.return_routed_experts for seq in seqs)
+
         def __need_schedule_again(prefill: bool, scheduler_output):
             """Need schedule again."""
             # only reschedule when prefill
@@ -939,6 +951,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
             inputs = self.create_model_inputs(running, prefill)
             sampling_inputs = self.sampling_strategy.make_sampling_inputs(running)
             return_logits = __need_logits(running)
+            return_routed_experts = __need_routed_experts(running)
             extra_inputs = self.model_agent_strategy.make_extra_inputs(running)
             stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)

@@ -956,6 +969,7 @@ def __need_schedule_again(prefill: bool, scheduler_output):
                 is_dummy=False,
                 sync_long_context=sync_long_context,
                 extra_inputs=extra_inputs,
+                return_routed_experts=return_routed_experts,
             )

     async def _await_forward_event(self, forward_event: asyncio.Event):
@@ -991,6 +1005,7 @@ def __send_resp(out: InferOutput):
                          logits=out.logits,
                          cache_block_ids=out.cache_block_ids,
                          req_metrics=out.req_metrics,
+                         routed_experts=out.routed_experts,
                          logprobs=logprobs))

         def __update_logprobs(step_outputs: List[InferOutput]):

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 4 additions & 0 deletions
@@ -152,6 +152,8 @@ async def async_stream_infer(self,
             cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
             req_metrics = resp.data.get('req_metrics', None) if resp.data else None
             logprobs = resp.data.pop('logprobs', None) if resp.data else None
+            routed_experts = resp.data.get('routed_experts', None) if resp.data else None
+
             if resp.type == ResponseType.SUCCESS:
                 token_ids = resp.data['token_ids'].tolist()
                 num_ids = len(token_ids) - output_offset
@@ -160,6 +162,7 @@ async def async_stream_infer(self,
                                    token_ids[output_offset:],
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
+                                   routed_experts=routed_experts,
                                    logprobs=logprobs)
                 output_offset = len(token_ids)
             elif resp.type == ResponseType.FINISH:
@@ -173,6 +176,7 @@ async def async_stream_infer(self,
                                    logits=logits,
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
+                                   routed_experts=routed_experts,
                                    logprobs=logprobs)
                 break
             else:

lmdeploy/pytorch/engine/executor/base.py

Lines changed: 4 additions & 0 deletions
@@ -101,6 +101,10 @@ def release(self):
         """Release resources."""
         raise NotImplementedError('Not Implemented.')

+    def serialize(self, obj):
+        """Serialize obj."""
+        return obj
+
     async def forward_async(self, inputs):
         """Start forward."""
         raise NotImplementedError('Not Implemented')

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
+import base64
 import contextlib
 import json
 import os
@@ -351,6 +352,13 @@ def wakeup(self, tags: Optional[List[str]] = None):
         self.update_configs()
         self.collective_rpc('wakeup', (tags, ))

+    def serialize(self, obj) -> str:
+        """Serialize obj."""
+        ref = ray.put(obj)
+        data = ray.cloudpickle.dumps(ref)
+        data = base64.b64encode(data).decode('utf-8')
+        return data
+
     def get_input_processor(self):
         """Build cache engine."""
         return ray.get(self.workers[0].get_input_processor.remote())
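serialize() returns a base64-encoded cloudpickle dump of a Ray ObjectRef rather than the tensor itself, so only a short string travels through the engine's response path when the API-server transfer mode (enable_transfer_obj_ref) is active. The consumer-side decode is not part of this commit; a minimal sketch of the inverse, assuming the consumer has called ray.init() against the same cluster that still holds the object:

# Hypothetical consumer-side decode of the string produced by serialize() above.
import base64

import ray


def deserialize_routed_experts(data: str):
    """Inverse of serialize(): base64 -> cloudpickle -> ObjectRef -> object."""
    ref = ray.cloudpickle.loads(base64.b64decode(data))
    return ray.get(ref)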
