8 changes: 6 additions & 2 deletions lmdeploy/messages.py
@@ -63,7 +63,7 @@ class GenerationConfig:
around special tokens. The behavior of Fast tokenizers is to have
this to False. This is setup to True in slow tokenizers.
logprobs (int): Number of log probabilities to return per output token.
response_format (Dict): Generate responses according to given formatting.
response_format (Dict): Only the pytorch backend supports formatted
responses. Examples:
{
"type": "json_schema",
@@ -445,6 +445,7 @@ class Response:
token_ids: (List[int]): the output token ids.
logprobs: (List[Dict[int, float]]): the top logprobs for each output
position.
step_map: (List[int]): the decoding step for each token in DLLM mode.
index (int): it refers to the position index of the input request
batch
"""
@@ -456,6 +457,7 @@ class Response:
logprobs: List[Dict[int, float]] = None
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
step_map: List[int] = None
index: int = 0

def __repr__(self):
@@ -464,7 +466,7 @@ def __repr__(self):
'last_hidden_state=None' if self.last_hidden_state is None else
f'last_hidden_state.shape={self.last_hidden_state.shape}\nlast_hidden_state={self.last_hidden_state}')
s = (f'text={self.text}\ngenerate_token_len={self.generate_token_len}\nfinish_reason="{self.finish_reason}"\n'
f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\n{logits}\n{hidden_state}')
f'token_ids={self.token_ids}\nlog_probs={self.logprobs}\nstep_map={self.step_map}\n{logits}\n{hidden_state}')
return s


@@ -537,6 +539,7 @@ class EngineOutput:
cache_block_ids (List[int]): send cache blocks back for migration in
Disaggregated LLM Serving when Prefill Engine is Done.
req_metrics (RequestMetrics): request metrics information
step_map (List[int]): the decoding step for each token in DLLM mode
"""
status: ResponseType
token_ids: List[int]
@@ -546,6 +549,7 @@ class EngineOutput:
last_hidden_state: torch.Tensor = None
cache_block_ids: Optional[List[int]] = None
req_metrics: Optional[RequestMetrics] = None
step_map: Optional[List[int]] = None


@dataclass
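The new `step_map` field is a per-token list parallel to `token_ids`: entry `k` records the DLLM decode step that produced token `k`, with 0 used for prefill and padding positions. Below is a minimal sketch (not part of the patch) of how a consumer might read it; the token values are hypothetical and a plain function stands in for the real `Response` class:

```python
from collections import defaultdict
from typing import Dict, List


def group_tokens_by_step(token_ids: List[int], step_map: List[int]) -> Dict[int, List[int]]:
    """Group generated token ids by the DLLM decode step that produced them."""
    assert len(token_ids) == len(step_map)
    groups: Dict[int, List[int]] = defaultdict(list)
    for token, step in zip(token_ids, step_map):
        groups[step].append(token)
    return dict(groups)


# hypothetical values for illustration only
token_ids = [11, 12, 13, 14, 15, 16]
step_map = [1, 1, 2, 2, 2, 3]  # tokens 0-1 unmasked at step 1, 2-4 at step 2, 5 at step 3
print(group_tokens_by_step(token_ids, step_map))
# {1: [11, 12], 2: [13, 14, 15], 3: [16]}
```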
25 changes: 15 additions & 10 deletions lmdeploy/pytorch/engine/engine.py
@@ -55,6 +55,9 @@ class InferOutput:

# for logging
req_metrics: RequestMetrics = None

# step map for DLLM: track the step at which each token was decoded
step_map: List[int] = None


def _tensorlize_block_offsets(block_offsets, dtype=torch.int32):
@@ -551,7 +554,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs):
if len(msgs) > 0 and msgs[0].preserve_cache:
self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED)
else:
self.end_session(session_id)
self.scheduler.end_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
self._response(req.resp, resp_type)
@@ -838,6 +841,7 @@ def _make_infer_outputs(
# generate output
outputs: Dict[int, InferOutput] = dict()
for idx, msg in enumerate(running):
if not is_run[idx]:
continue
token_ids = msg.generated_ids
@@ -860,14 +864,20 @@
if num_logprobs >= 0:
cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1])

# step_map: the decode step at which each generated token was produced
step_map = None
if hasattr(msg, 'generated_step_map'):
step_map = msg.generated_step_map.tolist()

req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
out = InferOutput(session_id=session_id,
resp=msg.resp,
finish=finish,
token_ids=token_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=cur_logprobs)
logprobs=cur_logprobs,
step_map=step_map)
outputs[session_id] = out

if msg.return_logits:
@@ -912,7 +922,6 @@ def __need_logits(seqs: SeqList):
stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)

sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num

return dict(
running=running,
inputs=inputs,
@@ -962,13 +971,15 @@ def __send_resp(out: InferOutput):
cur_logprobs = dict(zip(indices, vals))
logprobs = [] if out.resp.data is None else out.resp.data.get('logprobs', [])
logprobs = logprobs + [cur_logprobs]

self._response(out.resp,
resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics,
logprobs=logprobs))
logprobs=logprobs,
step_map=out.step_map))

def __send_resps(step_outputs: List[InferOutput]):
"""Send response callback."""
@@ -1212,11 +1223,6 @@ async def async_loop(self):
self._loop_finally()

def close(self):
if self.executor.device_type == 'cuda':
# https://discuss.pytorch.org/t/how-to-delete-a-tensor-in-gpu-to-free-up-memory/48879/32
# W/O this, repeatedly rebuilding and destroying engines within the same process
# will cause more and more reserved CUDA memory.
torch._C._cuda_clearCublasWorkspaces()
if self._loop_main is not None:
self._loop_main.cancel()
else:
@@ -1243,7 +1249,6 @@ def start_loop(self):
def end_session(self, session_id: int):
"""End session."""
if session_id in self.scheduler.sessions:
self.sampling_strategy.on_session_end(session_id)
self.scheduler.end_session(session_id)
return True
return False
8 changes: 6 additions & 2 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -151,6 +151,8 @@ async def async_stream_infer(self,
cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
req_metrics = resp.data.get('req_metrics', None) if resp.data else None
logprobs = resp.data.get('logprobs', None) if resp.data else None
step_map = resp.data.get('step_map', None) if resp.data else None

if resp.type == ResponseType.SUCCESS:
token_ids = resp.data['token_ids'].tolist()
num_ids = len(token_ids)
@@ -160,7 +162,8 @@
num_ids,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=logprobs)
logprobs=logprobs,
step_map=step_map)
elif resp.type == ResponseType.FINISH:
resp_data = resp.data
token_ids = resp_data['token_ids'].tolist()
@@ -173,7 +176,8 @@
logits=logits,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
logprobs=logprobs)
logprobs=logprobs,
step_map=step_map)
break
else:
logger.debug(f'session[{session_id}] failed.')
48 changes: 47 additions & 1 deletion lmdeploy/pytorch/strategies/dllm/sequence.py
@@ -34,19 +34,27 @@ class SchedulerSequenceDLLM(SchedulerSequenceDefault):

# For dllm
history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask)
history_step_map: HistoryTokenIds = field(default_factory=HistoryTokenIds) # track the decode step of each token

def __post_init__(self):
"""Post init."""
super().__post_init__()
self._num_valid_ids: int = len(self.history_cache)
self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy
self._current_step: int = 0 # current decoding step

@property
def dllm_mask(self):
start = self.num_history_ids
end = start + self._num_token_ids
return self.history_dllm_mask._token_ids[start:end]

@property
def step_map(self):
start = self.num_history_ids
end = start + self._num_token_ids
return self.history_step_map._token_ids[start:end]

@property
def num_valid_ids(self):
return self._num_valid_ids
@@ -57,6 +65,12 @@ def generated_ids(self) -> np.ndarray:
start = end - self.num_new_tokens
return self.history_cache._token_ids[start:end]

@property
def generated_step_map(self) -> np.ndarray:
end = self.num_valid_ids
start = end - self.num_new_tokens
return self.history_step_map._token_ids[start:end]

@property
def all_dllm_mask(self):
return self.history_dllm_mask._token_ids[:self.num_all_ids]
@@ -82,38 +96,48 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
dllm_mask_token = self.dllm_mask_token
new_token_ids = [token_ids]
new_dllm_mask = [dllm_mask]
# input-stage tokens are marked as step 0 (prefill)
new_step_map = [np.zeros(len(token_ids), dtype=np.int32)]

# add uncached tokens in token_ids
# for example, [cccc cccc uumm], the [uu] in last block is remain valid.
num_remain_valid = self.num_valid_ids - self.num_history_ids
if num_remain_valid != 0:
prev_token_ids = self.valid_ids[-num_remain_valid:]
prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
prev_step_map = self.history_step_map._token_ids[self.num_history_ids:self.num_history_ids + num_remain_valid]
new_token_ids = [prev_token_ids] + new_token_ids
new_dllm_mask = [prev_dllm_mask] + new_dllm_mask
new_step_map = [prev_step_map] + new_step_map
self.history_cache.resize(self.num_history_ids)
self.history_dllm_mask.resize(self.num_history_ids)
self.history_step_map.resize(self.num_history_ids)
num_tokens += num_remain_valid

# pad to align with dllm_block_length
num_pad = (-num_tokens) % dllm_block_length
if num_pad > 0:
pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, ))
pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, ))
pad_step = np.zeros(num_pad, dtype=np.int32)
new_token_ids += [pad_ids]
new_dllm_mask += [pad_mask]
new_step_map += [pad_step]

token_ids = np.concatenate(new_token_ids)
dllm_mask = np.concatenate(new_dllm_mask)
step_map = np.concatenate(new_step_map)

assert len(token_ids) % dllm_block_length == 0

self.history_cache.append(token_ids)
self.history_dllm_mask.append(dllm_mask)
self.history_step_map.append(step_map)
self.output_start_pos = self._num_valid_ids + len(token_ids)
self._num_valid_ids = self.num_history_ids + num_tokens
self._num_token_ids = len(token_ids)
self.num_new_tokens = 0
self._current_step = 0 # reset the step counter

def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
"""Update token ids for decode."""
@@ -123,9 +147,25 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
assert num_tokens % dllm_block_length == 0
num_history_ids = self.num_history_ids

# fetch the old mask and step_map to identify which tokens are newly decoded
old_mask = self.history_dllm_mask._token_ids[num_history_ids:num_history_ids + num_tokens].copy()
old_step_map = self.history_step_map._token_ids[num_history_ids:num_history_ids + num_tokens].copy()

# check whether any new tokens were unmasked
newly_unmasked = (old_mask == DLLM_MASKED) & (dllm_mask == DLLM_UNMASKED)

# only advance the step counter when new tokens were unmasked
if newly_unmasked.any():
self._current_step += 1

# update step_map: tokens flipped from MASKED to UNMASKED get the current step
new_step_map = old_step_map.copy()
new_step_map[newly_unmasked] = self._current_step

token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token
self.history_cache[num_history_ids:] = token_ids
self.history_dllm_mask[num_history_ids:] = dllm_mask
self.history_step_map[num_history_ids:] = new_step_map

# check if all blocks are cached
last_mask = dllm_mask[-dllm_block_length:]
@@ -141,8 +181,10 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray)
# add new block
new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, ))
new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, ))
new_step_map_block = np.zeros(dllm_block_length, dtype=np.int32)
self.history_cache.append(new_token_ids)
self.history_dllm_mask.append(new_dllm_mask)
self.history_step_map.append(new_step_map_block)
self._num_history_ids += self._num_token_ids
self._num_token_ids = dllm_block_length

@@ -154,9 +196,13 @@ def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray
# fill input cache
if self.num_token_ids > dllm_block_length:
end = self.num_token_ids - dllm_block_length
self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED
self.history_dllm_mask[num_history_ids:num_history_ids + end] = DLLM_CACHED
# tokens cached during prefill are marked step 0 (they are inputs, not part of generated_ids)
self.history_step_map[num_history_ids:num_history_ids + end] = 0
self._num_history_ids += end
self._num_token_ids -= end

# decoding starts after prefill; _current_step stays 0, so the first unmask becomes step 1

# decoding update
self._update_token_ids_decode(token_ids, dllm_mask)
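A minimal sketch of the decode-side step tracking implemented in `_update_token_ids_decode`: the step counter advances only when an iteration unmasks at least one token, and exactly those newly unmasked positions receive the current step. The mask values here are stand-ins for the real DLLM constants:

```python
import numpy as np

DLLM_UNMASKED, DLLM_MASKED = 0, 1  # stand-in mask values


def update_step_map(old_mask: np.ndarray, new_mask: np.ndarray,
                    step_map: np.ndarray, current_step: int):
    """Assign the decode step to tokens that were just unmasked.

    The step counter only advances when this iteration actually
    unmasked at least one token in the block.
    """
    newly_unmasked = (old_mask == DLLM_MASKED) & (new_mask == DLLM_UNMASKED)
    if newly_unmasked.any():
        current_step += 1
    new_step_map = step_map.copy()
    new_step_map[newly_unmasked] = current_step
    return new_step_map, current_step


# one 4-token block over two decode iterations (hypothetical masks)
step_map = np.zeros(4, dtype=np.int32)
step = 0
mask0 = np.full(4, DLLM_MASKED)
mask1 = np.array([DLLM_UNMASKED, DLLM_MASKED, DLLM_UNMASKED, DLLM_MASKED])
mask2 = np.full(4, DLLM_UNMASKED)

step_map, step = update_step_map(mask0, mask1, step_map, step)  # tokens 0, 2 -> step 1
step_map, step = update_step_map(mask1, mask2, step_map, step)  # tokens 1, 3 -> step 2
print(step_map.tolist(), step)  # [1, 2, 1, 2] 2
```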