Commit 4aed7a7

[TRTLLM-6853][feat] refactor deepseekv3 model (#6698)
Signed-off-by: linquanh <linquanh@nvidia.com>
1 parent ffc976c commit 4aed7a7

5 files changed: +100 -107 lines changed


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 29 additions & 68 deletions
@@ -53,7 +53,6 @@
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
                            MoEAllReduce, MoEAllReduceParams, allgather)
 from ..model_config import ModelConfig
-from ..models.modeling_utils import ModelConfig, QuantConfig
 from ..modules.attention import MLA
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -66,10 +65,10 @@
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..peft.lora.layer import LoraLayer
-from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
+from ..speculative import MTPSpecMetadata, SpecMetadata
 from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
-from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             EagerFusionConfig, filter_weights,
+from .modeling_speculative import SpecDecOneEngineForCausalLM
+from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
                              register_auto_model)


@@ -541,7 +540,8 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         router_logits = self.gate(hidden_states)

         routed_output = self.experts(
-            hidden_states_fp4 or hidden_states,
+            hidden_states_fp4
+            if hidden_states_fp4 is not None else hidden_states,
             router_logits,
             do_finalize=do_finalize,
             output_dtype=hidden_states.dtype,
@@ -565,8 +565,9 @@ def forward(
         assert not self.use_dp

         def _compute_shared_output():
-            shared_output = self.shared_experts(hidden_states_fp4
-                                                or hidden_states)
+            shared_output = self.shared_experts(
+                hidden_states_fp4
+                if hidden_states_fp4 is not None else hidden_states)
             if self.shared_output_scale is not None:
                 shared_output *= self.shared_output_scale
             return shared_output
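A quick note on the two rewrites above: the commit itself states no rationale, but replacing `x or y` with `x if x is not None else y` is the usual fix for optional tensor-like arguments, since `or` relies on the truthiness of `x` (ambiguous for multi-element tensors, and silently skipped for a "falsy" value), while the explicit check only tests for presence. A minimal illustration with plain tensors (hypothetical values, not the actual Fp4QuantizedTensor path):

import torch

maybe_fp4 = torch.zeros(4)       # stand-in for an optional pre-quantized activation
hidden_states = torch.ones(4)

# Explicit presence check: picks maybe_fp4 whenever it exists, even if all-zero.
picked = maybe_fp4 if maybe_fp4 is not None else hidden_states
print(picked)                    # tensor([0., 0., 0., 0.])

# Truthiness-based selection: ambiguous for multi-element tensors.
try:
    picked = maybe_fp4 or hidden_states
except RuntimeError as err:
    print("`or` on a multi-element tensor raises:", err)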
@@ -750,7 +751,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: torch.Tensor,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
             residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -782,7 +783,7 @@ def forward_MoE(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:

         def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
             return self.mlp(
@@ -866,7 +867,7 @@ def forward_mlp(
         self,
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:

         if self.fusion_config.PRE_MLP_FUSION:
             act_fp4, act_sf, residual = self.allreduce(
@@ -970,7 +971,7 @@ def forward(
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
         **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:

         def norm_embeds():
             return self.enorm(embed_tokens(input_ids))  #emdedding
@@ -1085,6 +1086,8 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
+        **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
@@ -1109,8 +1112,8 @@ def forward(


 @register_auto_model("DeepseekV3ForCausalLM")
-class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
-                                                    PretrainedConfig]):
+class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
+                                                        PretrainedConfig]):

     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         # Rename some keys of quant_config_dict to support legacy checkpoints
@@ -1125,10 +1128,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             model_config._frozen = False
             model_config.quant_config_dict = quant_config_dict
             model_config._frozen = True
-        super().__init__(DeepseekV3Model(model_config),
-                         config=model_config,
-                         hidden_size=model_config.pretrained_config.hidden_size,
-                         vocab_size=model_config.pretrained_config.vocab_size)
+
+        super().__init__(model=DeepseekV3Model(model_config),
+                         model_config=model_config)

         self.model_nextn = 0
         if model_config.spec_config is not None:
@@ -1138,23 +1140,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
             assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
             if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
                 moe_load_balancer_set_repeated_for_next_layer(model_nextn)
-                mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
-                                          self.model.aux_stream_dict)
-                self.model.layers.append(mtp_layer)
-                self.epilogue.append(mtp_layer)
-                self.mtp_worker = MTPEagleWorker(model_config.spec_config,
-                                                 model_config)
             else:
-                mtp_layers = nn.ModuleList([
-                    DeepseekV3MTP(model_config,
-                                  layer_idx + self.num_hidden_layers,
-                                  self.model.aux_stream_dict)
-                    for layer_idx in range(model_nextn)
-                ])
-                self.model.layers.extend(mtp_layers)
-                self.epilogue.extend(mtp_layers)
-                self.mtp_worker = MTPWorker(model_config.spec_config,
-                                            model_config)
                 # modify the QuantConfig to support duplicated mtp layers
                 if model_config.quant_config.exclude_modules is not None:
                     extend_exclude_modules = []
@@ -1172,7 +1158,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
                            ckpt_prefix, model_prefix))
                    self.model_config.quant_config.exclude_modules.extend(
                        extend_exclude_modules)
-            self.epilogue.append(self.mtp_worker)
+            self.model.layers.extend(self.draft_model.mtp_layers)
+            self.epilogue.extend(self.draft_model.mtp_layers)
+            self.epilogue.append(self.spec_worker)

     def forward(
         self,
@@ -1185,40 +1173,13 @@ def forward(
         **kwargs,
     ) -> torch.Tensor:
         attn_metadata.num_generations_per_batch = self.model_nextn + 1
-        hidden_states = self.model(
-            input_ids=input_ids,
-            attn_metadata=attn_metadata,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-        )
-
-        if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
-            # get logits
-            logits = self.logits_processor.forward(
-                hidden_states[spec_metadata.gather_ids],
-                self.lm_head,
-                attn_metadata,
-                True,
-            )
-            # get accepted tokens and next draft tokens
-            return self.mtp_worker(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                hidden_states=hidden_states,
-                logits=logits,
-                lm_head=self.lm_head,
-                embed_tokens=self.model.embed_tokens,
-                attn_metadata=attn_metadata,
-                spec_metadata=spec_metadata,
-                mtp_layers=self.model.layers[self.num_hidden_layers:])
-        else:
-            logits = self.logits_processor.forward(
-                hidden_states,
-                self.lm_head,
-                attn_metadata,
-                return_context_logits,
-            )
-            return logits
+        return super().forward(attn_metadata=attn_metadata,
+                               input_ids=input_ids,
+                               position_ids=position_ids,
+                               inputs_embeds=inputs_embeds,
+                               spec_metadata=spec_metadata,
+                               return_context_logits=return_context_logits,
+                               **kwargs)

     def load_weights(self, weights: Dict):
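Taken together, the changes in this file turn DeepseekV3ForCausalLM into a thin subclass: construction passes only the model and the config, and forward delegates all speculative-decoding branching to SpecDecOneEngineForCausalLM. The sketch below shows that delegation shape with toy stand-in classes; every name is hypothetical and the base-class body is a simplification of what the diff implies, not the real implementation.

import torch
from torch import nn


class TinyDecoderModel(nn.Module):  # hypothetical stand-in for DeepseekV3Model
    def __init__(self, hidden: int = 8, vocab: int = 32):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed_tokens(input_ids))


class SpecDecBase(nn.Module):  # hypothetical stand-in for SpecDecOneEngineForCausalLM
    def __init__(self, model: nn.Module, model_config: dict):
        super().__init__()
        self.model = model
        self.model_config = model_config
        self.lm_head = nn.Linear(8, model_config["vocab_size"])
        self.draft_model = None  # the real base class builds this from spec_config
        self.spec_worker = None  # ditto

    def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        # Simplified: run the target model and project to logits. The real base
        # class also gathers hidden states and drives the draft model/worker,
        # i.e. the logic removed from DeepseekV3ForCausalLM.forward above.
        return self.lm_head(self.model(input_ids))


class TinyForCausalLM(SpecDecBase):  # hypothetical stand-in for DeepseekV3ForCausalLM
    def __init__(self, model_config: dict):
        # Post-refactor constructor shape: only model and model_config are passed.
        super().__init__(model=TinyDecoderModel(), model_config=model_config)

    def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        # All speculative-decoding branching is delegated to the base class.
        return super().forward(input_ids=input_ids, **kwargs)


if __name__ == "__main__":
    lm = TinyForCausalLM({"vocab_size": 32})
    print(lm(torch.randint(0, 32, (2, 4))).shape)  # torch.Size([2, 4, 32])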

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 50 additions & 18 deletions
@@ -2,7 +2,7 @@

 import torch
 from torch import nn
-from transformers import LlamaConfig
+from transformers import LlamaConfig, PretrainedConfig

 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
@@ -320,14 +320,45 @@ def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-def get_draft_model(model_config, draft_config):
+class MTPForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        start_layer_idx: int = 0,
+        lm_head: nn.Module = None,
+        model: nn.Module = None,
+    ):
+        super().__init__()
+        # Import here to avoid circular import
+        from .modeling_deepseekv3 import DeepseekV3MTP
+
+        spec_dec_mode = model_config.spec_config.spec_dec_mode
+        assert spec_dec_mode.is_mtp()
+        mtp_num_layers = 1 if spec_dec_mode.is_mtp_eagle(
+        ) else model_config.spec_config.num_nextn_predict_layers
+
+        self.mtp_layers = nn.ModuleList([
+            DeepseekV3MTP(model_config, layer_idx + start_layer_idx,
+                          model.aux_stream_dict)
+            for layer_idx in range(mtp_num_layers)
+        ])
+        self.lm_head = lm_head
+        self.embed_tokens = model.embed_tokens
+
+
+def get_draft_model(model_config, draft_config, lm_head, model):
     assert getattr(model_config, 'spec_config', None) != None
     spec_dec_mode = model_config.spec_config.spec_dec_mode
     if spec_dec_mode.is_eagle3_one_model():
         return Eagle3ForCausalLM(
             draft_config, model_config.pretrained_config.num_hidden_layers)
+    elif spec_dec_mode.is_mtp():
+        return MTPForCausalLM(model_config,
+                              model_config.pretrained_config.num_hidden_layers,
+                              lm_head, model)
     else:
-        raise NotImplemented(
+        raise NotImplementedError(
            f"get_draft_model does not support speculative decoding mode {spec_dec_mode}."
        )
@@ -341,23 +372,24 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
                          hidden_size=model_config.pretrained_config.hidden_size,
                          vocab_size=model_config.pretrained_config.vocab_size)
         self.draft_model = None
-        if getattr(
-                model_config, 'spec_config', None
-        ) and model_config.spec_config.spec_dec_mode.use_one_engine():
-            draft_config = ModelConfig.from_pretrained(
-                model_config.spec_config.speculative_model_dir,
-                trust_remote_code=True,
-                attn_backend=model_config.attn_backend,
-                moe_backend=model_config.moe_backend,
-                mapping=model_config.mapping,
-                spec_config=model_config.spec_config,
-                max_num_tokens=model_config.max_num_tokens,
-                moe_max_num_tokens=model_config.moe_max_num_tokens)
-
-            draft_config.quant_config.kv_cache_quant_algo = \
+        spec_config = getattr(model_config, 'spec_config', None)
+        if spec_config and spec_config.spec_dec_mode.use_one_engine():
+            draft_config = None
+            if spec_config.spec_dec_mode.is_eagle3_one_model():
+                draft_config = ModelConfig.from_pretrained(
+                    model_config.spec_config.speculative_model_dir,
+                    trust_remote_code=True,
+                    attn_backend=model_config.attn_backend,
+                    moe_backend=model_config.moe_backend,
+                    mapping=model_config.mapping,
+                    spec_config=model_config.spec_config,
+                    max_num_tokens=model_config.max_num_tokens,
+                    moe_max_num_tokens=model_config.moe_max_num_tokens)
+                draft_config.quant_config.kv_cache_quant_algo = \
                     model_config.quant_config.kv_cache_quant_algo

-            self.draft_model = get_draft_model(model_config, draft_config)
+            self.draft_model = get_draft_model(model_config, draft_config,
+                                               self.lm_head, self.model)
             self.spec_worker = get_spec_worker(model_config.spec_config,
                                                model_config,
                                                model_config.mapping)
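The new MTPForCausalLM above is essentially a thin container: it owns the MTP layers and re-exports the target model's lm_head and embed_tokens, so a speculative worker can take a single draft_model argument. Below is a self-contained sketch of that wrapper-plus-dispatch shape; all classes are hypothetical stand-ins (plain Linear layers instead of DeepseekV3MTP and the Eagle3 draft model), not the TensorRT-LLM implementations.

import torch
from torch import nn


class ToyMTPDraft(nn.Module):  # stand-in for MTPForCausalLM
    def __init__(self, num_layers: int, hidden: int,
                 lm_head: nn.Module, embed_tokens: nn.Module):
        super().__init__()
        # The wrapper owns the MTP layers and re-exports the target model's
        # lm_head/embed_tokens, so downstream code only needs this one object.
        self.mtp_layers = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.lm_head = lm_head
        self.embed_tokens = embed_tokens


class ToyEagleDraft(nn.Module):  # stand-in for Eagle3ForCausalLM
    def __init__(self, hidden: int):
        super().__init__()
        self.fc = nn.Linear(hidden, hidden)


def get_toy_draft_model(mode: str, hidden: int,
                        lm_head: nn.Module, embed_tokens: nn.Module) -> nn.Module:
    # Mirrors the new dispatch: Eagle3 builds its own config-driven draft model,
    # while both MTP modes wrap MTP layers plus the target model's head/embeddings.
    if mode == "eagle3":
        return ToyEagleDraft(hidden)
    elif mode in ("mtp", "mtp_eagle"):
        num_layers = 1 if mode == "mtp_eagle" else 3
        return ToyMTPDraft(num_layers, hidden, lm_head, embed_tokens)
    raise NotImplementedError(f"unsupported speculative decoding mode {mode}")


if __name__ == "__main__":
    head, embed = nn.Linear(8, 32), nn.Embedding(32, 8)
    draft = get_toy_draft_model("mtp", hidden=8, lm_head=head, embed_tokens=embed)
    print(len(draft.mtp_layers), draft.lm_head is head)  # 3 True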

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ class SpeculativeDecodingMode(IntEnum):
     def is_mtp(self):
         return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE

+    def is_mtp_vanilla(self):
+        return self == SpeculativeDecodingMode.MTP
+
     def is_mtp_eagle(self):
         return self == SpeculativeDecodingMode.MTP_EAGLE
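For reference, the predicate family on SpeculativeDecodingMode now distinguishes vanilla MTP from MTP-Eagle explicitly. A minimal sketch of how the three predicates partition the modes, assuming only the two members visible in this hunk (the real enum has more members, and the values here are made up):

from enum import IntEnum


class Mode(IntEnum):  # hypothetical stand-in for SpeculativeDecodingMode
    MTP = 1
    MTP_EAGLE = 2

    def is_mtp(self):
        return self in (Mode.MTP, Mode.MTP_EAGLE)

    def is_mtp_vanilla(self):
        return self == Mode.MTP

    def is_mtp_eagle(self):
        return self == Mode.MTP_EAGLE


# Both modes count as "MTP", but only one of the narrower predicates holds for each.
assert Mode.MTP.is_mtp() and Mode.MTP.is_mtp_vanilla() and not Mode.MTP.is_mtp_eagle()
assert Mode.MTP_EAGLE.is_mtp() and Mode.MTP_EAGLE.is_mtp_eagle() and not Mode.MTP_EAGLE.is_mtp_vanilla()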

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 14 additions & 18 deletions
@@ -330,11 +330,9 @@ def forward(
         position_ids,
         hidden_states,
         logits,
-        lm_head,
-        embed_tokens,
         attn_metadata,
         spec_metadata,
-        mtp_layers,
+        draft_model,
     ):
         '''
         Example:
@@ -470,9 +468,10 @@ def forward(
         next_draft_tokens = []
         last_tokens_idx = torch.cumsum(
             attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
-        for _, mtp_layer in enumerate(mtp_layers):
-            hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs)
-            logits = mtp_layer.shared_head(hidden_states, lm_head,
+        for _, mtp_layer in enumerate(draft_model.mtp_layers):
+            hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
+                                      **draft_inputs)
+            logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
                                            attn_metadata).float()
             new_draft_token = self.draft_sampler(logits)
             next_draft_tokens.append(new_draft_token)
@@ -517,11 +516,9 @@ def skip_forward(
         position_ids,
         hidden_states,
         logits,
-        lm_head,
-        embed_tokens,
         attn_metadata,
         spec_metadata,
-        mtp_layers,
+        draft_model,
     ):
         batch_size = attn_metadata.num_seqs
         mtp_num_modules = self.spec_config.num_nextn_predict_layers
@@ -1127,11 +1124,9 @@ def forward(
         position_ids,
         hidden_states,
         logits,
-        lm_head,
-        embed_tokens,
         attn_metadata,
         spec_metadata,
-        mtp_layers,
+        draft_model,
     ):
         batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
@@ -1172,8 +1167,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
         next_draft_tokens = []
         for i in range(self.mtp_num_modules):
             if i == 0:
-                hidden_states = mtp_layers[0](
-                    embed_tokens=embed_tokens,
+                hidden_states = draft_model.mtp_layers[0](
+                    embed_tokens=draft_model.embed_tokens,
                     all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
                     all_rank_max_num_tokens=spec_metadata.
                     all_rank_max_num_tokens,
@@ -1186,8 +1181,8 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
                 gather_ids = torch.concat(
                     [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
             else:
-                hidden_states = mtp_layers[0](
-                    embed_tokens=embed_tokens,
+                hidden_states = draft_model.mtp_layers[0](
+                    embed_tokens=draft_model.embed_tokens,
                     all_rank_num_tokens=spec_metadata.
                     subseq_all_rank_num_tokens,
                     all_rank_max_num_tokens=max(
@@ -1197,8 +1192,9 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata):
                     **inputs)
                 # All of the seq_len are 1, use batch_indices_cuda as gather_ids
                 gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
-            logits = mtp_layers[0].shared_head(hidden_states[gather_ids],
-                                               lm_head, attn_metadata, True)
+            logits = draft_model.mtp_layers[0].shared_head(
+                hidden_states[gather_ids], draft_model.lm_head, attn_metadata,
+                True)
             new_draft_token = self.draft_sampler(logits)

             hidden_states, position_ids = self.update_draft_tokens(
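The common thread in these hunks is that every MTP worker entry point now receives one draft_model object in place of the separate lm_head, embed_tokens and mtp_layers arguments, and reads those pieces off the wrapper. A toy sketch of that narrowed interface follows (hypothetical stand-ins that mirror the per-layer draft loop shown above, not the real MTPWorker):

import torch
from torch import nn


class ToyDraftModel(nn.Module):
    # Carries everything a worker needs: layers, embeddings and the LM head.
    def __init__(self, hidden: int = 8, vocab: int = 32, num_layers: int = 2):
        super().__init__()
        self.mtp_layers = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab)


class ToyMTPWorker:
    def forward(self, hidden_states: torch.Tensor, draft_model: ToyDraftModel) -> list:
        # One greedy draft token per MTP layer, reading layers and head from
        # draft_model instead of taking them as separate arguments.
        next_draft_tokens = []
        for mtp_layer in draft_model.mtp_layers:
            hidden_states = mtp_layer(hidden_states)
            logits = draft_model.lm_head(hidden_states)
            next_draft_tokens.append(logits.argmax(dim=-1))
        return next_draft_tokens


if __name__ == "__main__":
    worker, draft = ToyMTPWorker(), ToyDraftModel()
    drafts = worker.forward(torch.randn(3, 8), draft)
    print(len(drafts), drafts[0].shape)  # 2 torch.Size([3])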
