
Commit 4cd8c85

do not remove rejected tokens in prefill to remove sync op
1 parent d72f284 commit 4cd8c85
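
In the prefill path of speculative decoding, the old code compacted the batch by physically removing rejected draft tokens before the draft forward; sizing the gather buffer from num_rejected_tokens.sum() forces a device-to-host sync. After this change the rejected tokens stay in place and only the per-sequence index of the last accepted token is recorded in spec_inputs.last_token_indices, which can be computed entirely on-device. A minimal sketch of that index computation, with invented tensor values:

import torch

# Invented example: three sequences with prefill query lengths 5, 3 and 4,
# of which 1, 0 and 2 trailing draft tokens were rejected.
seq_length = torch.tensor([5, 3, 4])
num_rejected_tokens = torch.tensor([1, 0, 2])

# Position of each sequence's last accepted token in the flattened batch.
# No .sum()/.item() on a GPU tensor is needed, so no host synchronization.
last_token_indices = seq_length.cumsum(0) - 1 - num_rejected_tokens
print(last_token_indices)  # tensor([3, 7, 9])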

File tree (5 files changed: +43 -78 lines)

  lmdeploy/pytorch/engine/model_agent.py
  lmdeploy/pytorch/model_inputs.py
  lmdeploy/pytorch/spec_decode/base.py
  lmdeploy/pytorch/spec_decode/deepseek_mtp.py
  lmdeploy/pytorch/spec_decode/eagle3.py

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 8 additions & 4 deletions
@@ -391,7 +391,8 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_ou
         await asyncio.sleep(0)
         return output
 
-    async def _async_model_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+    async def _async_model_forward(self, inputs: ModelInputs, spec_inputs: SpecDecodeInputs, swap_in_map: SwapMap,
+                                   swap_out_map: SwapMap):
         """Model forward.
 
         Args:
@@ -433,11 +434,11 @@ async def __long_context_single_forward(new_inputs):
         outputs = await __forward(inputs)
 
         loop_count = self.num_spec_tokens - 1
-        draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
+        draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs, spec_inputs)
         draft_tokens_li = [draft_token_ids]
         if loop_count > 0:
-            inputs = self.proposer.update_inputs_decoding(inputs, draft_token_ids.transpose(0, 1), target_hidden_states,
-                                                          model_metas)
+            inputs = self.proposer.update_inputs_decoding(inputs, spec_inputs, draft_token_ids.transpose(0, 1),
+                                                          target_hidden_states, model_metas)
         for loop_idx in range(loop_count):
             outputs = await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
             draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs)
@@ -466,14 +467,17 @@ async def async_model_forward(self,
             spec_inputs.num_rejected_tokens = num_rejected_tokens
             spec_inputs.reject_sample_tokens = output_token_ids
             spec_inputs.next_token_ids = last_token_ids
+            spec_inputs.last_token_indices = model_inputs.seq_length.cumsum(0) - 1 - num_rejected_tokens
         else:
             spec_inputs.next_token_ids = spec_inputs.bonus_token_ids
             output_token_ids = spec_inputs.next_token_ids.unsqueeze(-1)
+            spec_inputs.last_token_indices = model_inputs.seq_length.cumsum(0) - 1
 
         with record_function('draft_prepare_inputs'):
             draft_model_inputs = self.proposer.prepare_inputs(model_inputs, spec_inputs)
 
         new_draft_tokens = await self._async_model_forward(draft_model_inputs,
+                                                           spec_inputs,
                                                            swap_in_map=swap_in_map,
                                                            swap_out_map=swap_out_map)
         outputs = dict(output_token_ids=output_token_ids, spec_token_ids=new_draft_tokens)
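
As a sanity check on the change above (synthetic data, not the lmdeploy APIs): gathering hidden states at last_token_indices over the uncompacted prefill output picks out the same positions that the removed compact-then-take-last path would have selected.

import torch

torch.manual_seed(0)
seq_length = torch.tensor([5, 3, 4])
num_rejected = torch.tensor([1, 0, 2])
hidden = torch.randn(1, int(seq_length.sum()), 8)   # [1, total_tokens, hidden_size]

# New path: index directly into the uncompacted tensor.
last_token_indices = seq_length.cumsum(0) - 1 - num_rejected
new = hidden[:, last_token_indices]

# Old path: drop the rejected tail of each sequence, then take its last kept token.
starts = seq_length.cumsum(0) - seq_length
kept = [hidden[:, s:s + (l - r)] for s, l, r in zip(starts.tolist(), seq_length.tolist(), num_rejected.tolist())]
old = torch.cat([k[:, -1:] for k in kept], dim=1)

assert torch.equal(new, old)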

lmdeploy/pytorch/model_inputs.py

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ class SpecDecodeInputs:
     next_token_ids: torch.LongTensor = None
     num_rejected_tokens: torch.LongTensor = None
     reject_sample_tokens: torch.LongTensor = None
+    last_token_indices: torch.LongTensor = None
 
 
 @dataclass

lmdeploy/pytorch/spec_decode/base.py

Lines changed: 11 additions & 4 deletions
@@ -80,7 +80,10 @@ def build_model(self, empty_init: bool, target_model: torch.nn.Module = None):
         self.model = patched_model
         self.target_model = target_model
 
-    def get_outputs(self, model_outputs: Dict[str, torch.Tensor], model_inputs: ModelInputs):
+    def get_outputs(self,
+                    model_outputs: Dict[str, torch.Tensor],
+                    model_inputs: ModelInputs,
+                    spec_inputs: SpecDecodeInputs = None):
         """Get outputs."""
         raise NotImplementedError()
 
@@ -97,20 +100,24 @@ def _forward(self, model_inputs: ModelInputs, cache_engine: CacheEngine = None,
                              cache_engine=cache_engine,
                              stream=stream)
 
-    def update_inputs_decoding(self, model_inputs: ModelInputs, input_ids: torch.Tensor,
-                               target_hidden_states: torch.Tensor, model_metas: List[Any]):
+    def update_inputs_decoding(self, model_inputs: ModelInputs, spec_inputs: SpecDecodeInputs,
+                               next_input_ids: torch.Tensor, target_hidden_states: torch.Tensor,
+                               model_metas: List[Any]):
         """Update to decoding inputs."""
         model_inputs.is_decoding = True
         batch_size = model_inputs.seq_length.size(0)
-        model_inputs.input_ids = input_ids
+        model_inputs.input_ids = next_input_ids
         model_inputs.max_q_seqlen = 1
         model_inputs.max_kv_seqlen += 1
         model_inputs.sum_kv_seqlen += model_inputs.seq_length.numel()
         model_inputs.history_lengths += model_inputs.seq_length
+        if spec_inputs.num_rejected_tokens is not None:
+            model_inputs.history_lengths -= spec_inputs.num_rejected_tokens
         model_inputs.seq_length = model_inputs.seq_length.new_ones(batch_size)
         model_inputs.target_position_ids = model_inputs.history_lengths.unsqueeze(0).clone()
         model_inputs.model_metas = model_metas
         model_inputs.target_hidden_states = target_hidden_states
+        model_inputs.spec_metadata = None
         return model_inputs
 
     @record_function('draft_get_logits')
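
A worked example of the bookkeeping added to update_inputs_decoding (numbers invented): because the draft prefill now runs over the full, uncompacted query, history first grows by seq_length and the rejected tail tokens are then subtracted, which nets out to the same history the old compacted path produced.

import torch

# Invented per-sequence state after a speculative prefill step.
history_lengths = torch.tensor([10, 20, 30])   # tokens already in history before this step
seq_length = torch.tensor([5, 3, 4])           # uncompacted draft prefill query lengths
num_rejected_tokens = torch.tensor([1, 0, 2])  # rejected draft tokens per sequence

history_lengths = history_lengths + seq_length            # count everything that was fed in
history_lengths = history_lengths - num_rejected_tokens   # then drop the rejected tails
print(history_lengths)  # tensor([14, 23, 32])

# Switching to decoding: one new token per sequence, positioned right after its history.
seq_length = torch.ones_like(seq_length)
target_position_ids = history_lengths.unsqueeze(0).clone()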

lmdeploy/pytorch/spec_decode/deepseek_mtp.py

Lines changed: 15 additions & 61 deletions
@@ -2,8 +2,6 @@
 from typing import Dict
 
 import torch
-import triton
-import triton.language as tl
 
 from lmdeploy.utils import get_logger
 
@@ -16,54 +14,35 @@
 @SPEC_PROPOSERS.register_module(name='deepseek_mtp')
 class DeepseekMTP(BaseSpecProposer):
 
-    def get_outputs(self, model_outputs: Dict[str, torch.Tensor], model_inputs: ModelInputs):
+    def get_outputs(self,
+                    model_outputs: Dict[str, torch.Tensor],
+                    model_inputs: ModelInputs,
+                    spec_inputs: SpecDecodeInputs = None):
         """Get outputs."""
         hidden_states = model_outputs['hidden_states']
         model_metas = model_outputs['model_metas']
         if not model_inputs.is_decoding:
-            if model_inputs.seq_length.size(0) == 1:
+            assert spec_inputs is not None, 'spec_inputs should be provided for prefill mode'
+            if model_inputs.seq_length.size(0) == 1 and spec_inputs.num_rejected_tokens is None:
                 hidden_states = hidden_states[:, -1:]
             else:
-                last_token_loc = model_inputs.seq_length.cumsum(0) - 1
+                last_token_loc = spec_inputs.last_token_indices
                 hidden_states = hidden_states[:, last_token_loc]
+
         logits = self.get_logits(hidden_states)[0]
         draft_token_ids = logits.argmax(dim=-1, keepdim=True)
         return draft_token_ids, model_metas, hidden_states
 
     def prepare_inputs(self, model_inputs: ModelInputs, spec_inputs: SpecDecodeInputs):
         """Prepare inputs."""
         spec_metadata = model_inputs.spec_metadata
-
-        if spec_metadata.draft_token_ids is None:
-            input_ids = model_inputs.input_ids
-            seq_length = model_inputs.seq_length
-        else:
-            # select input ids
-            query_lens = model_inputs.seq_length
-            batch_size = query_lens.size(0)
-            cum_query_lens = query_lens.new_zeros((batch_size + 1), dtype=torch.long)
-            cum_qery_lens_new = query_lens.new_zeros((batch_size + 1), dtype=torch.long)
-            torch.cumsum(query_lens, dim=0, out=cum_query_lens[1:])
-            query_lens_new = query_lens - spec_inputs.num_rejected_tokens
-            torch.cumsum(query_lens_new, dim=0, out=cum_qery_lens_new[1:])
-            keep_token_indices = query_lens.new_zeros(
-                model_inputs.input_ids.size(1) - spec_inputs.num_rejected_tokens.sum())
-            cal_token_indices[(batch_size, )](keep_token_indices, cum_query_lens, cum_qery_lens_new, BLOCK_SIZE=1024)
-            input_ids = model_inputs.input_ids[:, keep_token_indices]
-            seq_length = query_lens_new
-
-            spec_inputs.target_hidden_states = spec_inputs.target_hidden_states[:, keep_token_indices]
-            if spec_inputs.target_position_ids is not None:
-                spec_inputs.target_position_ids = spec_inputs.target_position_ids[:, keep_token_indices]
-
-        # offset by 1 token
+        input_ids = model_inputs.input_ids
+        seq_length = model_inputs.seq_length
+        last_token_indices = spec_inputs.last_token_indices
+        # offset by 1 token
         input_ids[:, :-1] = input_ids[:, 1:].clone()
-        # update next tokens
-        if seq_length.size(0) == 1:
-            input_ids[:, -1:] = spec_inputs.next_token_ids
-        else:
-            last_token_indices = seq_length.cumsum(0) - 1
-            input_ids[:, last_token_indices] = spec_inputs.next_token_ids
+        # update next tokens
+        input_ids[:, last_token_indices] = spec_inputs.next_token_ids
         # use new inputs
         return ModelInputs(
             input_ids=input_ids,
@@ -77,30 +56,5 @@ def prepare_inputs(self, model_inputs: ModelInputs, spec_inputs: SpecDecodeInput
             is_decoding=model_inputs.is_decoding,
             target_hidden_states=spec_inputs.target_hidden_states,
             target_position_ids=spec_inputs.target_position_ids,
-        )
-
-
-@triton.jit
-def cal_token_indices(
-    token_indices_ptr,
-    cum_query_lens_ptr,
-    cum_new_query_lens_ptr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """Calculate the token indices based on rejection sampler results."""
-    pid = tl.program_id(0)
-
-    start_pos = tl.load(cum_new_query_lens_ptr + pid)
-    end_pos = tl.load(cum_new_query_lens_ptr + pid + 1)
-    num_tokens = end_pos - start_pos
-
-    index_start = tl.load(cum_query_lens_ptr + pid)
-
-    num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
-    for i in tl.range(num_blocks):
-        offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        tl.store(
-            token_indices_ptr + start_pos + offset,
-            index_start + offset,
-            mask=offset < num_tokens,
+            spec_metadata=spec_metadata,
         )
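
A toy rendition of the simplified prepare_inputs above (token ids and shapes invented): the draft input is the target sequence shifted left by one token, and the slot at each sequence's last accepted position is overwritten with the freshly sampled next token; the rejected positions are simply left in place rather than being gathered out.

import torch

# Flattened batch of three sequences (lengths 5, 3, 4), ids invented.
input_ids = torch.tensor([[11, 12, 13, 14, 15, 21, 22, 23, 31, 32, 33, 34]])
last_token_indices = torch.tensor([3, 7, 9])   # seq_length.cumsum(0) - 1 - num_rejected_tokens
next_token_ids = torch.tensor([16, 24, 35])    # sampled by the target model (made up)

input_ids[:, :-1] = input_ids[:, 1:].clone()        # offset by 1 token
input_ids[:, last_token_indices] = next_token_ids   # write the accepted next tokens
print(input_ids)
# tensor([[12, 13, 14, 16, 21, 22, 23, 24, 32, 35, 34, 34]])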

lmdeploy/pytorch/spec_decode/eagle3.py

Lines changed: 8 additions & 9 deletions
@@ -6,7 +6,7 @@
 from lmdeploy.utils import get_logger
 
 from ..config import ModelConfig
-from ..model_inputs import ModelInputs
+from ..model_inputs import ModelInputs, SpecDecodeInputs
 from .base import SPEC_PROPOSERS
 from .deepseek_mtp import DeepseekMTP
 
@@ -30,27 +30,26 @@ def get_target_hidden_size(self, model_config: ModelConfig):
         hidden_size = getattr(hf_config, 'target_hidden_size', hf_config.hidden_size)
         return hidden_size * 3
 
-    def get_outputs(self, model_outputs: Dict[str, torch.Tensor], model_inputs: ModelInputs):
+    def get_outputs(self,
+                    model_outputs: Dict[str, torch.Tensor],
+                    model_inputs: ModelInputs,
+                    spec_inputs: SpecDecodeInputs = None):
         """Get outputs."""
         hidden_states = model_outputs['hidden_states']
         hidden_states_prenorm = model_outputs['hidden_states_prenorm']
         model_metas = model_outputs['model_metas']
         if not model_inputs.is_decoding:
-            if model_inputs.seq_length.size(0) == 1:
+            assert spec_inputs is not None, 'spec_inputs should be provided for prefill mode'
+            if model_inputs.seq_length.size(0) == 1 and spec_inputs.num_rejected_tokens is None:
                 hidden_states = hidden_states[:, -1:]
                 hidden_states_prenorm = hidden_states_prenorm[:, -1:]
             else:
-                last_token_loc = model_inputs.seq_length.cumsum(0) - 1
+                last_token_loc = spec_inputs.last_token_indices
                 hidden_states = hidden_states[:, last_token_loc]
                 hidden_states_prenorm = hidden_states_prenorm[:, last_token_loc]
 
         logits = self.get_logits(hidden_states)[0]
         draft_token_ids = logits.argmax(dim=-1, keepdim=True)
-        device = draft_token_ids.device
-        dtype = draft_token_ids.dtype
         # token mapping
-        if self.draft_id_to_target_id.device != device or self.draft_id_to_target_id.dtype != dtype:
-            self.draft_id_to_target_id = self.draft_id_to_target_id.to(dtype=draft_token_ids.dtype,
-                                                                       device=draft_token_ids.device)
         draft_token_ids = self.draft_id_to_target_id[draft_token_ids]
         return draft_token_ids, model_metas, hidden_states_prenorm
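
The draft_id_to_target_id lookup kept in the last hunk maps ids from the draft head's vocabulary back into the target model's vocabulary; a minimal illustration with an invented table:

import torch

# Invented 4-entry mapping table: draft vocab id -> target vocab id.
draft_id_to_target_id = torch.tensor([101, 205, 7, 4096])
draft_logits = torch.randn(3, 4)  # [num_sequences, draft_vocab_size]

draft_token_ids = draft_logits.argmax(dim=-1, keepdim=True)  # ids in the draft vocabulary
draft_token_ids = draft_id_to_target_id[draft_token_ids]     # remapped to target ids, shape [3, 1]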
