diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md
new file mode 100644
index 0000000000..31bb70b14b
--- /dev/null
+++ b/PR_DESCRIPTION.md
@@ -0,0 +1,100 @@
+## Motivation
+
+In DLLM (Diffusion LLM) mode, tokens are generated in blocks and progressively unmasked in a non-sequential order. Currently there is no way to track the decoding step at which each token was revealed. This information is valuable for:
+- Analyzing DLLM decoding efficiency
+- Understanding the non-sequential token generation pattern
+- Debugging and optimizing unmasking strategies
+- Research on the parallel decoding behavior of diffusion LLMs
+
+This PR adds a `step_map` field to track the decoding step number for each generated token.
+
+## Modification
+
+This PR introduces a `step_map` feature that records the step at which each token was decoded in DLLM mode:
+
+1. **Core tracking logic** (`lmdeploy/pytorch/strategies/dllm/sequence.py`):
+   - Added `history_step_map` field to `SchedulerSequenceDLLM` to store step numbers
+   - Added `_current_step` counter to track decoding steps
+   - Added `step_map` and `generated_step_map` properties
+   - Updated `_update_token_ids_decode()` to record step numbers when tokens transition from MASKED to UNMASKED
+   - The step counter only increments when new tokens are actually unmasked
+
+2. **Engine layer** (`lmdeploy/pytorch/engine/engine.py`):
+   - Added `step_map` field to the `InferOutput` dataclass
+   - Extract `step_map` from messages in `_make_infer_outputs()`
+   - Propagate `step_map` through the response data
+
+3. **Instance layer** (`lmdeploy/pytorch/engine/engine_instance.py`):
+   - Extract and pass `step_map` to `EngineOutput`
+
+4. **API layer** (`lmdeploy/messages.py`):
+   - Added `step_map` field to the `Response` dataclass
+   - Added `step_map` field to the `EngineOutput` dataclass
+   - Updated `Response.__repr__()` to display step_map
+
+5. **Async engine layer** (`lmdeploy/serve/async_engine.py`):
+   - Added `step_map` field to the `GenOut` dataclass
+   - Updated `_gen_out_to_response()` to pass step_map
+   - Updated `_append_response()` to accumulate step_map across iterations
+   - Extract the incremental step_map from engine outputs in the generation loop
+
+**How it works:**
+- Each token gets a step number indicating when it was unmasked (1, 2, 3, ...)
+- The step_map array has the same length as the generated tokens
+- Non-sequential order in step_map reflects DLLM's parallel decoding behavior (see the sketch below)
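+
+A minimal sketch of the recording rule (illustrative only; it mirrors the update performed in `_update_token_ids_decode()`, but uses stand-in mask values rather than lmdeploy's `DLLM_MASKED`/`DLLM_UNMASKED` constants):
+
+```python
+import numpy as np
+
+MASKED, UNMASKED = 1, 0  # stand-in values, not the real DLLM mask constants
+
+
+def record_step(old_mask, new_mask, step_map, current_step):
+    """Assign the next step number to tokens that flip MASKED -> UNMASKED."""
+    newly_unmasked = (old_mask == MASKED) & (new_mask == UNMASKED)
+    if newly_unmasked.any():  # the counter advances only when something was revealed
+        current_step += 1
+        step_map[newly_unmasked] = current_step
+    return step_map, current_step
+
+
+step_map = np.zeros(4, dtype=np.int32)                 # one slot per token in a block of 4
+old = np.array([MASKED, MASKED, MASKED, MASKED])
+new = np.array([UNMASKED, MASKED, UNMASKED, MASKED])   # two tokens revealed in this pass
+step_map, current_step = record_step(old, new, step_map, current_step=0)
+print(step_map)  # [1 0 1 0] -> positions 0 and 2 were decoded at step 1
+```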
+
+## BC-breaking (Optional)
+
+**No breaking changes.** This is a backward-compatible addition:
+- New optional `step_map` field defaults to `None` in all dataclasses
+- Existing code will continue to work without modification
+- Only DLLM mode populates step_map; other modes return `None`
+
+## Use cases (Optional)
+
+**Example usage:**
+
+```python
+from lmdeploy import pipeline, PytorchEngineConfig, GenerationConfig
+
+# Configure DLLM
+backend_config = PytorchEngineConfig(
+    dllm_block_length=4,
+    dllm_unmasking_strategy="low_confidence_dynamic",
+)
+
+with pipeline(model_path, backend_config=backend_config) as pipe:
+    gen_config = GenerationConfig(max_new_tokens=100)
+    outputs = pipe(["Hello"], gen_config=gen_config)
+
+    for output in outputs:
+        if output.step_map is not None:
+            print(f"Tokens: {output.token_ids}")
+            print(f"Step map: {output.step_map}")
+            # Example: [1, 2, 1, 1, 3, 3, 3, 3, ...]
+            # Shows non-sequential unmasking pattern
+```
+
+**Analysis example:**
+
+```python
+from collections import Counter
+
+# Analyze decoding efficiency
+step_counts = Counter(output.step_map)
+for step in sorted(step_counts.keys()):
+    print(f"Step {step}: {step_counts[step]} tokens decoded")
+```
+
+This helps researchers:
+- Measure average tokens decoded per step (see the snippet below)
+- Evaluate unmasking strategy effectiveness
+- Compare different DLLM configurations
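+
+For example, a rough per-step efficiency figure can be read directly off `step_map` (assuming `output` is the `Response` object from the usage example above):
+
+```python
+step_map, token_ids = output.step_map, output.token_ids
+
+# Average number of tokens revealed per decoding step.
+num_steps = max(step_map)
+print(f"{len(step_map) / num_steps:.2f} tokens per step on average")
+
+# Group the generated tokens by the step that revealed them.
+by_step = {}
+for token, step in zip(token_ids, step_map):
+    by_step.setdefault(step, []).append(token)
+for step in sorted(by_step):
+    print(f"step {step}: {by_step[step]}")
+```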
+
+## Checklist
+
+1. [x] Pre-commit or other linting tools are used to fix the potential lint issues.
+2. [x] The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
+3. [x] If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
+4. [x] The documentation has been modified accordingly, like docstring or example tutorials.
+
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 13bd3a6d6a..80eb50a99e 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -445,6 +445,7 @@ class Response:
         token_ids: (List[int]): the output token ids.
         logprobs: (List[Dict[int, float]]): the top logprobs for each output
             position.
+        step_map: (List[int]): the decoding step for each token in DLLM mode.
         index (int): it refers to the position index of the input request batch
     """

@@ -456,6 +457,7 @@ class Response:
     logprobs: List[Dict[int, float]] = None
     logits: torch.Tensor = None
     last_hidden_state: torch.Tensor = None
+    step_map: List[int] = None
     index: int = 0

     def __repr__(self):
@@ -464,7 +466,7 @@ def __repr__(self):
                         'last_hidden_state=None' if self.last_hidden_state is None else
                         f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
         s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
-             f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}')
+             f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\nstep_map={self.step_map}\n{logits}\n{hidden_state}')

         return s

@@ -537,6 +539,7 @@ class EngineOutput:
         cache_block_ids (List[int]): send cache blocks back for migration in
             Disaggregated LLM Serving when Prefill Engine is Done.
         req_metrics (RequestMetrics): request metrics information
+        step_map (List[int]): the decoding step for each token in DLLM mode
     """
     status: ResponseType
     token_ids: List[int]
@@ -546,6 +549,7 @@ class EngineOutput:
     last_hidden_state: torch.Tensor = None
     cache_block_ids: Optional[List[int]] = None
     req_metrics: Optional[RequestMetrics] = None
+    step_map: Optional[List[int]] = None


 @dataclass
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 923e3799f8..5b6cd8837d 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -55,6 +55,9 @@ class InferOutput:

     # for logging
     req_metrics: RequestMetrics = None
+
+    # step map for DLLM: the step at which each token was decoded
+    step_map: List[int] = None


 def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
@@ -551,7 +554,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
             if len(msgs) > 0 and msgs[0].preserve_cache:
                 self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED)
             else:
-                self.end_session(session_id)
+                self.scheduler.end_session(session_id)
             resp_type = ResponseType.SUCCESS
         if resp:
             self._response(req.resp, resp_type)
@@ -842,6 +845,7 @@ def _make_infer_outputs(
         # generate output
         outputs: Dict[int, InferOutput] = dict()
         for idx, msg in enumerate(running):
+            # print(f"{idx}: {msg}")
             if not is_run[idx]:
                 continue
             token_ids = msg.generated_ids
@@ -864,6 +868,11 @@ def _make_infer_outputs(
             if num_logprobs >= 0:
                 cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])

+            # step_map: the decoding step of each generated token
+            step_map = None
+            if hasattr(msg, 'generated_step_map'):
+                step_map = msg.generated_step_map.tolist()
+
             req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
             out = InferOutput(session_id=session_id,
                               resp=msg.resp,
@@ -871,7 +880,8 @@ def _make_infer_outputs(
                               token_ids=token_ids,
                              cache_block_ids=cache_block_ids,
                               req_metrics=req_metrics,
-                              logprobs=cur_logprobs)
+                              logprobs=cur_logprobs,
+                              step_map=step_map)
             outputs[session_id] = out

             if msg.return_logits:
@@ -916,7 +926,6 @@ def __need_logits(seqs: SeqList):
         stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)

         sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num
-
         return dict(
             running=running,
             inputs=inputs,
@@ -964,7 +973,8 @@ def __send_resp(out: InferOutput):
                               logits=out.logits,
                               cache_block_ids=out.cache_block_ids,
                               req_metrics=out.req_metrics,
-                              logprobs=logprobs))
+                              logprobs=logprobs,
+                              step_map=out.step_map))

         def __update_logprobs(step_outputs: List[InferOutput]):
             for out in step_outputs:
@@ -1262,7 +1272,6 @@ def start_loop(self):
     def end_session(self, session_id: int):
         """End session."""
         if session_id in self.scheduler.sessions:
-            self.sampling_strategy.on_session_end(session_id)
             self.scheduler.end_session(session_id)
             return True
         return False
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index e62d59ca6b..7dff553d8d 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -151,7 +151,9 @@ async def async_stream_infer(self,
             cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
             req_metrics = resp.data.get('req_metrics', None) if resp.data else None
-            logprobs = resp.data.pop('logprobs', None) if resp.data else None
+            logprobs = resp.data.get('logprobs', None) if resp.data else None
+            step_map = resp.data.get('step_map', None) if resp.data else None
+
             if resp.type == ResponseType.SUCCESS:
                 token_ids = resp.data['token_ids'].tolist()
                 num_ids = len(token_ids) - output_offset
@@ -161,8 +163,8 @@ async def async_stream_infer(self,
                                    num_ids,
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
-                                   logprobs=logprobs)
-                output_offset = len(token_ids)
+                                   logprobs=logprobs,
+                                   step_map=step_map)
             elif resp.type == ResponseType.FINISH:
                 resp_data = resp.data
                 token_ids = resp_data['token_ids'].tolist()
@@ -175,7 +177,8 @@ async def async_stream_infer(self,
                                    logits=logits,
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
-                                   logprobs=logprobs)
+                                   logprobs=logprobs,
+                                   step_map=step_map)
                 break
             else:
                 logger.debug(f'session[{session_id}] failed.')
diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py
index ab004a2b63..a02f85812e 100644
--- a/lmdeploy/pytorch/strategies/dllm/sequence.py
+++ b/lmdeploy/pytorch/strategies/dllm/sequence.py
@@ -34,12 +34,14 @@ class SchedulerSequenceDLLM(SchedulerSequenceDefault):

     # For dllm
     history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask)
+    history_step_map: HistoryTokenIds = field(default_factory=HistoryTokenIds)  # tracks the decode step of each token

     def __post_init__(self):
         """Post init."""
         super().__post_init__()
         self._num_valid_ids: int = len(self.history_cache)
         self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy
+        self._current_step: int = 0  # current decoding step

     @property
     def dllm_mask(self):
@@ -47,6 +49,12 @@ def dllm_mask(self):
         end = start + self._num_token_ids
         return self.history_dllm_mask._token_ids[start:end]

+    @property
+    def step_map(self):
+        start = self.num_history_ids
+        end = start + self._num_token_ids
+        return self.history_step_map._token_ids[start:end]
+
     @property
     def num_valid_ids(self):
         return self._num_valid_ids
@@ -57,6 +65,12 @@ def generated_ids(self) -> np.ndarray:
         start = end - self.num_new_tokens
         return self.history_cache._token_ids[start:end]

+    @property
+    def generated_step_map(self) -> np.ndarray:
+        end = self.num_valid_ids
+        start = end - self.num_new_tokens
+        return self.history_step_map._token_ids[start:end]
+
     @property
     def all_dllm_mask(self):
         return self.history_dllm_mask._token_ids[:self.num_all_ids]
@@ -82,6 +96,8 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
         dllm_mask_token = self.dllm_mask_token
         new_token_ids = [token_ids]
         new_dllm_mask = [dllm_mask]
+        # input tokens are marked as step 0 (prefill)
+        new_step_map = [np.zeros(len(token_ids), dtype=np.int32)]

         # add uncached tokens in token_ids
         # for example, [cccc cccc uumm], the [uu] in last block is remain valid.
@@ -89,10 +105,13 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
         if num_remain_valid != 0:
             prev_token_ids = self.valid_ids[-num_remain_valid:]
             prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
+            prev_step_map = self.history_step_map._token_ids[self.num_history_ids:self.num_history_ids + num_remain_valid]
             new_token_ids = [prev_token_ids] + new_token_ids
             new_dllm_mask = [prev_dllm_mask] + new_dllm_mask
+            new_step_map = [prev_step_map] + new_step_map
             self.history_cache.resize(self.num_history_ids)
             self.history_dllm_mask.resize(self.num_history_ids)
+            self.history_step_map.resize(self.num_history_ids)
             num_tokens += num_remain_valid

         # pad to align with dllm_block_length
@@ -100,20 +119,25 @@
         if num_pad > 0:
             pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, ))
             pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, ))
+            pad_step = np.zeros(num_pad, dtype=np.int32)
             new_token_ids += [pad_ids]
             new_dllm_mask += [pad_mask]
+            new_step_map += [pad_step]

         token_ids = np.concatenate(new_token_ids)
         dllm_mask = np.concatenate(new_dllm_mask)
+        step_map = np.concatenate(new_step_map)
         assert len(token_ids) % dllm_block_length == 0
         self.history_cache.append(token_ids)
         self.history_dllm_mask.append(dllm_mask)
+        self.history_step_map.append(step_map)

         self.output_start_pos = self._num_valid_ids + len(token_ids)
         self._num_valid_ids = self.num_history_ids + num_tokens
         self._num_token_ids = len(token_ids)
         self.num_new_tokens = 0
+        self._current_step = 0  # reset the step counter

     def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
         """Update token ids for decode."""
@@ -123,9 +147,25 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
         assert num_tokens % dllm_block_length == 0
         num_history_ids = self.num_history_ids

+        # read the previous mask and step_map to find which tokens were newly decoded
+        old_mask = self.history_dllm_mask._token_ids[num_history_ids:num_history_ids + num_tokens].copy()
+        old_step_map = self.history_step_map._token_ids[num_history_ids:num_history_ids + num_tokens].copy()
+
+        # tokens that flipped from MASKED to UNMASKED in this pass
+        newly_unmasked = (old_mask == DLLM_MASKED) & (dllm_mask == DLLM_UNMASKED)
+
+        # only advance the step counter when new tokens were actually unmasked
+        if newly_unmasked.any():
+            self._current_step += 1
+
+        # update step_map: tokens that changed from MASKED to UNMASKED get the current step number
+        new_step_map = old_step_map.copy()
+        new_step_map[newly_unmasked] = self._current_step
+
         token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token
         self.history_cache[num_history_ids:] = token_ids
         self.history_dllm_mask[num_history_ids:] = dllm_mask
+        self.history_step_map[num_history_ids:] = new_step_map

         # check if all blocks are cached
         last_mask = dllm_mask[-dllm_block_length:]
@@ -141,8 +181,10 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
             # add new block
             new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, ))
             new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, ))
+            new_step_map_block = np.zeros(dllm_block_length, dtype=np.int32)
             self.history_cache.append(new_token_ids)
             self.history_dllm_mask.append(new_dllm_mask)
+            self.history_step_map.append(new_step_map_block)
             self._num_history_ids += self._num_token_ids
             self._num_token_ids = dllm_block_length
@@ -154,9 +196,13 @@ def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
         # fill input cache
         if self.num_token_ids > dllm_block_length:
             end = self.num_token_ids - dllm_block_length
-            self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED
+            self.history_dllm_mask[num_history_ids:num_history_ids + end] = DLLM_CACHED
+            # cached prefill tokens are marked as step 0 (they are inputs, not part of generated_ids)
+            self.history_step_map[num_history_ids:num_history_ids + end] = 0
             self._num_history_ids += end
             self._num_token_ids -= end
+
+        # decoding starts after prefill; _current_step stays 0, so the first unmasked tokens get step 1

         # decoding update
         self._update_token_ids_decode(token_ids, dllm_mask)
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index eed7abe630..927f4f69d8 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -13,6 +13,7 @@
 from threading import Thread
 from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union

+import torch
 import tqdm

 from lmdeploy import Tokenizer
@@ -97,6 +98,9 @@ class GenOut:

     # for disaggregation
     cache_block_ids: List[int] = None
+
+    # for DLLM
+    step_map: List[int] = None


 def _gen_out_to_response(out: GenOut, index) -> Response:
@@ -108,6 +112,7 @@ def _gen_out_to_response(out: GenOut, index) -> Response:
                     logprobs=out.logprobs,
                     last_hidden_state=out.last_hidden_state,
                     logits=out.logits,
+                    step_map=out.step_map,
                     index=index)


@@ -125,6 +130,9 @@ def _append_response(dst: Response, src: Response):
     if src.logprobs:
         dst.logprobs = dst.logprobs or []
         dst.logprobs += src.logprobs
+    if src.step_map:
+        dst.step_map = dst.step_map or []
+        dst.step_map += src.step_map
     return dst


@@ -357,6 +365,7 @@ def close(self):
         self.free_insts = None
         self.instances.clear()
         self.engine.close()
+        torch._C._cuda_clearCublasWorkspaces()

     def __enter__(self):
         return self
@@ -658,9 +667,11 @@ async def _get_prompt_input(self,
         # Change multimodal data to openai text messages, i.e.,
         # [{'role': 'user', 'content': [{'type': 'text', 'text': 'hi'}]}] ->
         # [{'role': 'user', 'content': 'hi']
-        # Also ensure all messages have 'content' field (set to None if missing, e.g., assistant with tool_calls)
-        if isinstance(prompt, list):
-            prompt = [_merge_message_content(msg) for msg in prompt]
+        if isinstance(prompt, list) and any(isinstance(msg['content'], list) for msg in prompt):
+            prompt = [
+                msg if isinstance(msg['content'], str) else dict(role=msg['role'], content=msg['content'][0]['text'])
+                for msg in prompt
+            ]
         if do_preprocess:
             # use adapter's chat template if possible
             chat_template = self.chat_template
@@ -864,6 +875,11 @@ def is_error(status):
                     token_ids += outputs.token_ids
                     gen_len = len(token_ids) - input_len

+                    # extract the step_map increment produced in this iteration
+                    step_map_increment = None
+                    if hasattr(outputs, 'step_map') and outputs.step_map is not None:
+                        step_map_increment = outputs.step_map[mask]
+
                     ids_offset = state.ids_offset
                     response, state = self.tokenizer.detokenize_incrementally(
@@ -879,7 +895,9 @@ def is_error(status):
                         gen_len,
                         finish_reason,
                         token_ids=res,
-                        cache_block_ids=outputs.cache_block_ids)
+                        cache_block_ids=outputs.cache_block_ids,
+                        step_map=step_map_increment)
+
                     if outputs.logprobs is not None:
                         out.logprobs = (outputs.logprobs[:-hit_stop_token] if hit_stop_token else outputs.logprobs)
                     if outputs.last_hidden_state is not None:
@@ -909,6 +927,10 @@ def is_error(status):
                     logger.info(f'session {session_id} finished, reason '
                                 f'"{finish_reason}", input_tokens '
                                 f'{len(input_ids)}, output_tokens {gen_len}')
+
+                    # the final FINISH output needs no step_map of its own (it has already been accumulated from the streamed outputs)
+                    final_step_map = None
+
                     yield GenOut(response,
                                  self.id2step[session_id],
                                  len(input_ids),
@@ -918,6 +940,7 @@ def is_error(status):
                                  logprobs=logprobs,
                                  logits=logits,
                                  last_hidden_state=last_hidden_state,
+                                 step_map=final_step_map,
                                  cache_block_ids=outputs.cache_block_ids)
                     # Update a session's sequence only when it is in finished status
                     if outputs.status == ResponseType.FINISH: