5353from ..distributed import (AllReduce , AllReduceFusionOp , AllReduceParams ,
5454 MoEAllReduce , MoEAllReduceParams , allgather )
5555from ..model_config import ModelConfig
56- from ..models .modeling_utils import ModelConfig , QuantConfig
5756from ..modules .attention import MLA
5857from ..modules .decoder_layer import DecoderLayer
5958from ..modules .embedding import Embedding
6665from ..modules .multi_stream_utils import maybe_execute_in_parallel
6766from ..modules .rms_norm import RMSNorm
6867from ..peft .lora .layer import LoraLayer
69- from ..speculative import MTPEagleWorker , MTPSpecMetadata , MTPWorker
68+ from ..speculative import MTPSpecMetadata , SpecMetadata
7069from ..utils import AuxStreamType , EventType , Fp4QuantizedTensor
71- from .modeling_utils import ( DecoderModel , DecoderModelForCausalLM ,
72- EagerFusionConfig , filter_weights ,
70+ from .modeling_speculative import SpecDecOneEngineForCausalLM
71+ from . modeling_utils import ( DecoderModel , EagerFusionConfig , filter_weights ,
7372 register_auto_model )
7473
7574
@@ -541,7 +540,8 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
541540 router_logits = self .gate (hidden_states )
542541
543542 routed_output = self .experts (
544- hidden_states_fp4 or hidden_states ,
543+ hidden_states_fp4
544+ if hidden_states_fp4 is not None else hidden_states ,
545545 router_logits ,
546546 do_finalize = do_finalize ,
547547 output_dtype = hidden_states .dtype ,
@@ -565,8 +565,9 @@ def forward(
565565 assert not self .use_dp
566566
567567 def _compute_shared_output ():
568- shared_output = self .shared_experts (hidden_states_fp4
569- or hidden_states )
568+ shared_output = self .shared_experts (
569+ hidden_states_fp4
570+ if hidden_states_fp4 is not None else hidden_states )
570571 if self .shared_output_scale is not None :
571572 shared_output *= self .shared_output_scale
572573 return shared_output
@@ -750,7 +751,7 @@ def forward(
750751 attn_metadata : AttentionMetadata ,
751752 residual : torch .Tensor ,
752753 ** kwargs ,
753- ) -> torch .Tensor :
754+ ) -> Tuple [ torch .Tensor , torch . Tensor ] :
754755 if residual is None :
755756 residual = hidden_states
756757 hidden_states = self .input_layernorm (hidden_states )
@@ -782,7 +783,7 @@ def forward_MoE(
782783 hidden_states : torch .Tensor ,
783784 attn_metadata : AttentionMetadata ,
784785 residual : torch .Tensor ,
785- ) -> torch .Tensor :
786+ ) -> Tuple [ torch .Tensor , torch . Tensor ] :
786787
787788 def _run_MoE (hidden_states , hidden_states_fp4 , do_finalize ):
788789 return self .mlp (
@@ -866,7 +867,7 @@ def forward_mlp(
866867 self ,
867868 hidden_states : torch .Tensor ,
868869 residual : torch .Tensor ,
869- ) -> torch .Tensor :
870+ ) -> Tuple [ torch .Tensor , torch . Tensor ] :
870871
871872 if self .fusion_config .PRE_MLP_FUSION :
872873 act_fp4 , act_sf , residual = self .allreduce (
@@ -970,7 +971,7 @@ def forward(
970971 all_rank_num_tokens : Optional [List [int ]] = None ,
971972 all_rank_max_num_tokens : Optional [int ] = None ,
972973 ** kwargs ,
973- ) -> Tuple [ torch .Tensor , torch . Tensor ] :
974+ ) -> torch .Tensor :
974975
975976 def norm_embeds ():
976977 return self .enorm (embed_tokens (input_ids )) #emdedding
@@ -1085,6 +1086,8 @@ def forward(
10851086 input_ids : Optional [torch .IntTensor ] = None ,
10861087 position_ids : Optional [torch .IntTensor ] = None ,
10871088 inputs_embeds : Optional [torch .FloatTensor ] = None ,
1089+ spec_metadata : Optional [SpecMetadata ] = None ,
1090+ ** kwargs ,
10881091 ) -> torch .Tensor :
10891092 if (input_ids is None ) ^ (inputs_embeds is not None ):
10901093 raise ValueError (
@@ -1109,8 +1112,8 @@ def forward(
11091112
11101113
11111114@register_auto_model ("DeepseekV3ForCausalLM" )
1112- class DeepseekV3ForCausalLM (DecoderModelForCausalLM [DeepseekV3Model ,
1113- PretrainedConfig ]):
1115+ class DeepseekV3ForCausalLM (SpecDecOneEngineForCausalLM [DeepseekV3Model ,
1116+ PretrainedConfig ]):
11141117
11151118 def __init__ (self , model_config : ModelConfig [PretrainedConfig ]):
11161119 # Rename some keys of quant_config_dict to support legacy checkpoints
@@ -1125,10 +1128,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
11251128 model_config ._frozen = False
11261129 model_config .quant_config_dict = quant_config_dict
11271130 model_config ._frozen = True
1128- super ().__init__ (DeepseekV3Model (model_config ),
1129- config = model_config ,
1130- hidden_size = model_config .pretrained_config .hidden_size ,
1131- vocab_size = model_config .pretrained_config .vocab_size )
1131+
1132+ super ().__init__ (model = DeepseekV3Model (model_config ),
1133+ model_config = model_config )
11321134
11331135 self .model_nextn = 0
11341136 if model_config .spec_config is not None :
@@ -1138,23 +1140,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
11381140 assert ckpt_nextn > 0 , "There is not MTP modules in the checkpoint."
11391141 if ckpt_nextn == 1 and not model_config .spec_config .use_mtp_vanilla :
11401142 moe_load_balancer_set_repeated_for_next_layer (model_nextn )
1141- mtp_layer = DeepseekV3MTP (model_config , self .num_hidden_layers ,
1142- self .model .aux_stream_dict )
1143- self .model .layers .append (mtp_layer )
1144- self .epilogue .append (mtp_layer )
1145- self .mtp_worker = MTPEagleWorker (model_config .spec_config ,
1146- model_config )
11471143 else :
1148- mtp_layers = nn .ModuleList ([
1149- DeepseekV3MTP (model_config ,
1150- layer_idx + self .num_hidden_layers ,
1151- self .model .aux_stream_dict )
1152- for layer_idx in range (model_nextn )
1153- ])
1154- self .model .layers .extend (mtp_layers )
1155- self .epilogue .extend (mtp_layers )
1156- self .mtp_worker = MTPWorker (model_config .spec_config ,
1157- model_config )
11581144 # modify the QuantConfig to support duplicated mtp layers
11591145 if model_config .quant_config .exclude_modules is not None :
11601146 extend_exclude_modules = []
@@ -1172,7 +1158,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
11721158 ckpt_prefix , model_prefix ))
11731159 self .model_config .quant_config .exclude_modules .extend (
11741160 extend_exclude_modules )
1175- self .epilogue .append (self .mtp_worker )
1161+ self .model .layers .extend (self .draft_model .mtp_layers )
1162+ self .epilogue .extend (self .draft_model .mtp_layers )
1163+ self .epilogue .append (self .spec_worker )
11761164
11771165 def forward (
11781166 self ,
@@ -1185,40 +1173,13 @@ def forward(
11851173 ** kwargs ,
11861174 ) -> torch .Tensor :
11871175 attn_metadata .num_generations_per_batch = self .model_nextn + 1
1188- hidden_states = self .model (
1189- input_ids = input_ids ,
1190- attn_metadata = attn_metadata ,
1191- position_ids = position_ids ,
1192- inputs_embeds = inputs_embeds ,
1193- )
1194-
1195- if spec_metadata and spec_metadata .spec_dec_mode .is_mtp ():
1196- # get logits
1197- logits = self .logits_processor .forward (
1198- hidden_states [spec_metadata .gather_ids ],
1199- self .lm_head ,
1200- attn_metadata ,
1201- True ,
1202- )
1203- # get accepted tokens and next draft tokens
1204- return self .mtp_worker (
1205- input_ids = input_ids ,
1206- position_ids = position_ids ,
1207- hidden_states = hidden_states ,
1208- logits = logits ,
1209- lm_head = self .lm_head ,
1210- embed_tokens = self .model .embed_tokens ,
1211- attn_metadata = attn_metadata ,
1212- spec_metadata = spec_metadata ,
1213- mtp_layers = self .model .layers [self .num_hidden_layers :])
1214- else :
1215- logits = self .logits_processor .forward (
1216- hidden_states ,
1217- self .lm_head ,
1218- attn_metadata ,
1219- return_context_logits ,
1220- )
1221- return logits
1176+ return super ().forward (attn_metadata = attn_metadata ,
1177+ input_ids = input_ids ,
1178+ position_ids = position_ids ,
1179+ inputs_embeds = inputs_embeds ,
1180+ spec_metadata = spec_metadata ,
1181+ return_context_logits = return_context_logits ,
1182+ ** kwargs )
12221183
12231184 def load_weights (self , weights : Dict ):
12241185
0 commit comments