22from typing import Dict
33
44import torch
5- import triton
6- import triton .language as tl
75
86from lmdeploy .utils import get_logger
97
@SPEC_PROPOSERS.register_module(name='deepseek_mtp')
class DeepseekMTP(BaseSpecProposer):
    """Deepseek multi-token-prediction (MTP) speculative-decoding proposer."""

    def get_outputs(self,
                    model_outputs: Dict[str, torch.Tensor],
                    model_inputs: ModelInputs,
                    spec_inputs: SpecDecodeInputs = None):
        """Propose one greedy draft token per sequence.

        Args:
            model_outputs: forward outputs; must contain ``'hidden_states'``
                (tokens laid out on dim 1) and ``'model_metas'``.
            model_inputs: the inputs used for this forward pass.
            spec_inputs: required in prefill (non-decoding) mode; supplies
                ``last_token_indices`` and ``num_rejected_tokens``.

        Returns:
            Tuple of ``(draft_token_ids, model_metas, hidden_states)`` where
            ``draft_token_ids`` is the per-position argmax with a trailing
            ``keepdim`` axis.
        """
        hidden_states = model_outputs['hidden_states']
        model_metas = model_outputs['model_metas']
        if not model_inputs.is_decoding:
            # Prefill: reduce hidden states to each sequence's last token
            # before computing logits.
            # NOTE: raise instead of `assert` so the guard survives `-O`.
            if spec_inputs is None:
                raise ValueError('spec_inputs should be provided for prefill mode')
            if model_inputs.seq_length.size(0) == 1 and spec_inputs.num_rejected_tokens is None:
                # Single sequence with no rejected tokens: the last position
                # is simply the final slot.
                hidden_states = hidden_states[:, -1:]
            else:
                # Gather the last-token position of every sequence.
                last_token_loc = spec_inputs.last_token_indices
                hidden_states = hidden_states[:, last_token_loc]

        logits = self.get_logits(hidden_states)[0]
        # Greedy draft proposal: argmax token id at each kept position.
        draft_token_ids = logits.argmax(dim=-1, keepdim=True)
        return draft_token_ids, model_metas, hidden_states
3235
3336 def prepare_inputs (self , model_inputs : ModelInputs , spec_inputs : SpecDecodeInputs ):
3437 """Prepare inputs."""
3538 spec_metadata = model_inputs .spec_metadata
36-
37- if spec_metadata .draft_token_ids is None :
38- input_ids = model_inputs .input_ids
39- seq_length = model_inputs .seq_length
40- else :
41- # select input ids
42- query_lens = model_inputs .seq_length
43- batch_size = query_lens .size (0 )
44- cum_query_lens = query_lens .new_zeros ((batch_size + 1 ), dtype = torch .long )
45- cum_qery_lens_new = query_lens .new_zeros ((batch_size + 1 ), dtype = torch .long )
46- torch .cumsum (query_lens , dim = 0 , out = cum_query_lens [1 :])
47- query_lens_new = query_lens - spec_inputs .num_rejected_tokens
48- torch .cumsum (query_lens_new , dim = 0 , out = cum_qery_lens_new [1 :])
49- keep_token_indices = query_lens .new_zeros (
50- model_inputs .input_ids .size (1 ) - spec_inputs .num_rejected_tokens .sum ())
51- cal_token_indices [(batch_size , )](keep_token_indices , cum_query_lens , cum_qery_lens_new , BLOCK_SIZE = 1024 )
52- input_ids = model_inputs .input_ids [:, keep_token_indices ]
53- seq_length = query_lens_new
54-
55- spec_inputs .target_hidden_states = spec_inputs .target_hidden_states [:, keep_token_indices ]
56- if spec_inputs .target_position_ids is not None :
57- spec_inputs .target_position_ids = spec_inputs .target_position_ids [:, keep_token_indices ]
58-
59- # offset by 1 token
39+ input_ids = model_inputs .input_ids
40+ seq_length = model_inputs .seq_length
41+ last_token_indices = spec_inputs .last_token_indices
42+ # # offset by 1 token
6043 input_ids [:, :- 1 ] = input_ids [:, 1 :].clone ()
61- # update next tokens
62- if seq_length .size (0 ) == 1 :
63- input_ids [:, - 1 :] = spec_inputs .next_token_ids
64- else :
65- last_token_indices = seq_length .cumsum (0 ) - 1
66- input_ids [:, last_token_indices ] = spec_inputs .next_token_ids
44+ # # update next tokens
45+ input_ids [:, last_token_indices ] = spec_inputs .next_token_ids
6746 # use new inputs
6847 return ModelInputs (
6948 input_ids = input_ids ,
@@ -77,30 +56,5 @@ def prepare_inputs(self, model_inputs: ModelInputs, spec_inputs: SpecDecodeInput
7756 is_decoding = model_inputs .is_decoding ,
7857 target_hidden_states = spec_inputs .target_hidden_states ,
7958 target_position_ids = spec_inputs .target_position_ids ,
80- )
81-
82-
83- @triton .jit
84- def cal_token_indices (
85- token_indices_ptr ,
86- cum_query_lens_ptr ,
87- cum_new_query_lens_ptr ,
88- BLOCK_SIZE : tl .constexpr ,
89- ):
90- """Calculate the token indices based on rejection sampler results."""
91- pid = tl .program_id (0 )
92-
93- start_pos = tl .load (cum_new_query_lens_ptr + pid )
94- end_pos = tl .load (cum_new_query_lens_ptr + pid + 1 )
95- num_tokens = end_pos - start_pos
96-
97- index_start = tl .load (cum_query_lens_ptr + pid )
98-
99- num_blocks = tl .cdiv (num_tokens , BLOCK_SIZE )
100- for i in tl .range (num_blocks ):
101- offset = i * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
102- tl .store (
103- token_indices_ptr + start_pos + offset ,
104- index_start + offset ,
105- mask = offset < num_tokens ,
59+ spec_metadata = spec_metadata ,
10660 )
0 commit comments