Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions PR_DESCRIPTION.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
## Motivation

In DLLM (Disaggregated LLM) mode, tokens are generated in blocks and progressively unmasked in a non-sequential order. Currently, there is no way to track which decoding step each token was revealed in. This information is valuable for:
- Analyzing DLLM decoding efficiency
- Understanding the non-sequential token generation pattern
- Debugging and optimizing unmasking strategies
- Research on speculative decoding performance

This PR adds a `step_map` field to track the decoding step number for each generated token.

## Modification

This PR introduces a `step_map` feature that records which step each token was decoded in DLLM mode:

1. **Core tracking logic** (`lmdeploy/pytorch/strategies/dllm/sequence.py`):
- Added `history_step_map` field to `SchedulerSequenceDLLM` to store step numbers
- Added `_current_step` counter to track decoding steps
- Added `step_map` and `generated_step_map` properties
- Updated `_update_token_ids_decode()` to record step numbers when tokens transition from MASKED to UNMASKED
- Step counter only increments when new tokens are actually unmasked

2. **Engine layer** (`lmdeploy/pytorch/engine/engine.py`):
- Added `step_map` field to `InferOutput` dataclass
- Extract `step_map` from messages in `_make_infer_outputs()`
- Propagate `step_map` through response data

3. **Instance layer** (`lmdeploy/pytorch/engine/engine_instance.py`):
- Extract and pass `step_map` to `EngineOutput`

4. **API layer** (`lmdeploy/messages.py`):
- Added `step_map` field to `Response` dataclass
- Added `step_map` field to `EngineOutput` dataclass
- Updated `Response.__repr__()` to display step_map

5. **Async engine layer** (`lmdeploy/serve/async_engine.py`):
- Added `step_map` field to `GenOut` dataclass
- Updated `_gen_out_to_response()` to pass step_map
- Updated `_append_response()` to accumulate step_map across iterations
- Extract incremental step_map from engine outputs in generation loop

**How it works:**
- Each token gets a step number indicating when it was unmasked (1, 2, 3, ...)
- The step_map array has the same length as the generated tokens
- Non-sequential order in step_map reflects DLLM's parallel decoding behavior

## BC-breaking (Optional)

**No breaking changes.** This is a backward-compatible addition:
- New optional `step_map` field defaults to `None` in all dataclasses
- Existing code will continue to work without modification
- Only DLLM mode populates step_map; other modes return `None`

## Use cases (Optional)

**Example usage:**

```python
from lmdeploy import pipeline, PytorchEngineConfig, GenerationConfig

# Configure DLLM
backend_config = PytorchEngineConfig(
dllm_block_length=4,
dllm_unmasking_strategy="low_confidence_dynamic",
)

with pipeline(model_path, backend_config=backend_config) as pipe:
gen_config = GenerationConfig(max_new_tokens=100)
outputs = pipe(["Hello"], gen_config=gen_config)

for output in outputs:
if output.step_map is not None:
print(f"Tokens: {output.token_ids}")
print(f"Step map: {output.step_map}")
# Example: [1, 2, 1, 1, 3, 3, 3, 3, ...]
# Shows non-sequential unmasking pattern
```

**Analysis example:**

```python
from collections import Counter

# Analyze decoding efficiency
step_counts = Counter(output.step_map)
for step in sorted(step_counts.keys()):
print(f"Step {step}: {step_counts[step]} tokens decoded")
```

This helps researchers:
- Measure average tokens decoded per step
- Evaluate unmasking strategy effectiveness
- Compare different DLLM configurations

## Checklist

1. [x] Pre-commit or other linting tools are used to fix the potential lint issues.
2. [x] The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
3. [x] If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
4. [x] The documentation has been modified accordingly, like docstring or example tutorials.

6 changes: 5 additions & 1 deletion lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ class Response:
token_ids: (List[int]): the output token ids.
logprobs: (List[Dict[int, float]]): the top logprobs for each output
position.
step_map: (List[int]): the decoding step for each token in DLLM mode.
index (int): it refers to the position index of the input request
batch
"""
Expand All @@ -456,6 +457,7 @@ class Response:
logprobs: List[Dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
step_map: List[int] = None
index: int = 0

def __repr__(self):
Expand All @@ -464,7 +466,7 @@ def __repr__(self):
'last_hidden_state=None' if self.last_hidden_state is None else
f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}')
f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\nstep_map={self.step_map}\n{logits}\n{hidden_state}')
return s


Expand Down Expand Up @@ -537,6 +539,7 @@ class EngineOutput:
cache_block_ids (List[int]): send cache blocks back for migration in
Disaggregated LLM Serving when Prefill Engine is Done.
req_metrics (RequestMetrics): request metrics information
step_map (List[int]): the decoding step for each token in DLLM mode
"""
status: ResponseType
token_ids: List[int]
Expand All @@ -546,6 +549,7 @@ class EngineOutput:
last_hidden_state: torch.Tensor = None
cache_block_ids: Optional[List[int]] = None
req_metrics: Optional[RequestMetrics] = None
step_map: Optional[List[int]] = None


@dataclass
Expand Down
19 changes: 14 additions & 5 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class InferOutput:

# for logging
req_metrics: RequestMetrics = None

# step map for DLLM: track which step each token was decoded
step_map: List[int] = None


def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
Expand Down Expand Up @@ -551,7 +554,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
if len(msgs) > 0 and msgs[0].preserve_cache:
self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED)
else:
self.end_session(session_id)
self.scheduler.end_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
self._response(req.resp, resp_type)
Expand Down Expand Up @@ -842,6 +845,7 @@ def _make_infer_outputs(
# generate output
outputs: Dict[int, InferOutput] = dict()
for idx, msg in enumerate(running):
# print(f"{idx}: {msg}")
if not is_run[idx]:
continue
token_ids = msg.generated_ids
Expand All @@ -864,14 +868,20 @@ def _make_infer_outputs(
if num_logprobs >= 0:
cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])

# step_map: 获取每个 token 被解码的步数
step_map = None
if hasattr(msg, 'generated_step_map'):
step_map = msg.generated_step_map.tolist()

req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
out = InferOutput(session_id=session_id,
resp=msg.resp,
finish=finish,
token_ids=token_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=cur_logprobs)
logprobs=cur_logprobs,
step_map=step_map)
outputs[session_id] = out

if msg.return_logits:
Expand Down Expand Up @@ -916,7 +926,6 @@ def __need_logits(seqs: SeqList):
stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)

sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num

return dict(
running=running,
inputs=inputs,
Expand Down Expand Up @@ -964,7 +973,8 @@ def __send_resp(out: InferOutput):
logits=out.logits,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics,
logprobs=logprobs))
logprobs=logprobs,
step_map=out.step_map))

def __update_logprobs(step_outputs: List[InferOutput]):
for out in step_outputs:
Expand Down Expand Up @@ -1262,7 +1272,6 @@ def start_loop(self):
def end_session(self, session_id: int):
"""End session."""
if session_id in self.scheduler.sessions:
self.sampling_strategy.on_session_end(session_id)
self.scheduler.end_session(session_id)
return True
return False
Expand Down
11 changes: 7 additions & 4 deletions lmdeploy/pytorch/engine/engine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ async def async_stream_infer(self,

cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
req_metrics = resp.data.get('req_metrics', None) if resp.data else None
logprobs = resp.data.pop('logprobs', None) if resp.data else None
logprobs = resp.data.get('logprobs', None) if resp.data else None
step_map = resp.data.get('step_map', None) if resp.data else None

if resp.type == ResponseType.SUCCESS:
token_ids = resp.data['token_ids'].tolist()
num_ids = len(token_ids) - output_offset
Expand All @@ -161,8 +163,8 @@ async def async_stream_infer(self,
num_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=logprobs)
output_offset = len(token_ids)
logprobs=logprobs,
step_map=step_map)
elif resp.type == ResponseType.FINISH:
resp_data = resp.data
token_ids = resp_data['token_ids'].tolist()
Expand All @@ -175,7 +177,8 @@ async def async_stream_infer(self,
logits=logits,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=logprobs)
logprobs=logprobs,
step_map=step_map)
break
else:
logger.debug(f'session[{session_id}] failed.')
Expand Down
48 changes: 47 additions & 1 deletion lmdeploy/pytorch/strategies/dllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,27 @@ class SchedulerSequenceDLLM(SchedulerSequenceDefault):

# For dllm
history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask)
history_step_map: HistoryTokenIds = field(default_factory=HistoryTokenIds) # 添加 step_map 追踪

def __post_init__(self):
"""Post init."""
super().__post_init__()
self._num_valid_ids: int = len(self.history_cache)
self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy
self._current_step: int = 0 # 当前解码步数

@property
def dllm_mask(self):
start = self.num_history_ids
end = start + self._num_token_ids
return self.history_dllm_mask._token_ids[start:end]

@property
def step_map(self):
start = self.num_history_ids
end = start + self._num_token_ids
return self.history_step_map._token_ids[start:end]

@property
def num_valid_ids(self):
return self._num_valid_ids
Expand All @@ -57,6 +65,12 @@ def generated_ids(self) -> np.ndarray:
start = end - self.num_new_tokens
return self.history_cache._token_ids[start:end]

@property
def generated_step_map(self) -> np.ndarray:
end = self.num_valid_ids
start = end - self.num_new_tokens
return self.history_step_map._token_ids[start:end]

@property
def all_dllm_mask(self):
return self.history_dllm_mask._token_ids[:self.num_all_ids]
Expand All @@ -82,38 +96,48 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
dllm_mask_token = self.dllm_mask_token
new_token_ids = [token_ids]
new_dllm_mask = [dllm_mask]
# 输入阶段标记为步骤 0(prefill)
new_step_map = [np.zeros(len(token_ids), dtype=np.int32)]

# add uncached tokens in token_ids
# for example, [cccc cccc uumm], the [uu] in last block is remain valid.
num_remain_valid = self.num_valid_ids - self.num_history_ids
if num_remain_valid != 0:
prev_token_ids = self.valid_ids[-num_remain_valid:]
prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
prev_step_map = self.history_step_map._token_ids[self.num_history_ids:self.num_history_ids + num_remain_valid]
new_token_ids = [prev_token_ids] + new_token_ids
new_dllm_mask = [prev_dllm_mask] + new_dllm_mask
new_step_map = [prev_step_map] + new_step_map
self.history_cache.resize(self.num_history_ids)
self.history_dllm_mask.resize(self.num_history_ids)
self.history_step_map.resize(self.num_history_ids)
num_tokens += num_remain_valid

# pad to align with dllm_block_length
num_pad = (-num_tokens) % dllm_block_length
if num_pad > 0:
pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, ))
pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, ))
pad_step = np.zeros(num_pad, dtype=np.int32)
new_token_ids += [pad_ids]
new_dllm_mask += [pad_mask]
new_step_map += [pad_step]

token_ids = np.concatenate(new_token_ids)
dllm_mask = np.concatenate(new_dllm_mask)
step_map = np.concatenate(new_step_map)

assert len(token_ids) % dllm_block_length == 0

self.history_cache.append(token_ids)
self.history_dllm_mask.append(dllm_mask)
self.history_step_map.append(step_map)
self.output_start_pos = self._num_valid_ids + len(token_ids)
self._num_valid_ids = self.num_history_ids + num_tokens
self._num_token_ids = len(token_ids)
self.num_new_tokens = 0
self._current_step = 0 # 重置步数计数器

def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
"""Update token ids for decode."""
Expand All @@ -123,9 +147,25 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
assert num_tokens % dllm_block_length == 0
num_history_ids = self.num_history_ids

# 获取旧的 mask 和 step_map 来判断哪些 token 是新解码的
old_mask = self.history_dllm_mask._token_ids[num_history_ids:num_history_ids + num_tokens].copy()
old_step_map = self.history_step_map._token_ids[num_history_ids:num_history_ids + num_tokens].copy()

# 检查是否有新的 tokens 被 unmask
newly_unmasked = (old_mask == DLLM_MASKED) & (dllm_mask == DLLM_UNMASKED)

# 只有当有新 tokens 被 unmask 时才递增步数
if newly_unmasked.any():
self._current_step += 1

# 更新 step_map:对于从 MASKED 变成 UNMASKED 的 token,设置为当前步数
new_step_map = old_step_map.copy()
new_step_map[newly_unmasked] = self._current_step

token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token
self.history_cache[num_history_ids:] = token_ids
self.history_dllm_mask[num_history_ids:] = dllm_mask
self.history_step_map[num_history_ids:] = new_step_map

# check if all blocks are cached
last_mask = dllm_mask[-dllm_block_length:]
Expand All @@ -141,8 +181,10 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
# add new block
new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, ))
new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, ))
new_step_map_block = np.zeros(dllm_block_length, dtype=np.int32)
self.history_cache.append(new_token_ids)
self.history_dllm_mask.append(new_dllm_mask)
self.history_step_map.append(new_step_map_block)
self._num_history_ids += self._num_token_ids
self._num_token_ids = dllm_block_length

Expand All @@ -154,9 +196,13 @@ def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray
# fill input cache
if self.num_token_ids > dllm_block_length:
end = self.num_token_ids - dllm_block_length
self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED
self.history_dllm_mask[num_history_ids:num_history_ids + end] = DLLM_CACHED
# prefill 阶段缓存的部分标记为 0(这些是输入,不在 generated_ids 中)
self.history_step_map[num_history_ids:num_history_ids + end] = 0
self._num_history_ids += end
self._num_token_ids -= end

# prefill 后开始 decode,保持 _current_step = 0,第一次 unmask 会变成 1

# decoding update
self._update_token_ids_decode(token_ids, dllm_mask)
Expand Down
Loading