import torch

from rtp_llm.config.model_config import ModelConfig
-from rtp_llm.ops import MlaOpsType
from rtp_llm.model_factory_register import register_model
from rtp_llm.model_loader.attn_weight import MlaAttnAtomicWeight, MlaConfig
from rtp_llm.model_loader.ffn_weight import (
)
from rtp_llm.models_py.model_desc.generic_moe import GenericMoeModel
from rtp_llm.models_py.model_desc.module_base import GptModelBase
+from rtp_llm.ops import MlaOpsType
from rtp_llm.utils.model_weight import (
    CkptWeightInfo,
    W,
@@ -73,7 +73,8 @@ def _get_hf_layer_weight_info(self, layer_id: int):
            kv_lora_rank=self.kv_lora_rank,
            nope_head_dim=self.nope_head_dim,
            v_head_dim=self.v_head_dim,
-            use_mla=self.model_config.attn_config.use_mla and self.model_config.mla_ops_type != MlaOpsType.MHA,
+            use_mla=self.model_config.attn_config.use_mla
+            and self.model_config.mla_ops_type != MlaOpsType.MHA,
            q_use_lora=self.q_use_lora,
        )
        layer_weights = [
@@ -225,7 +226,10 @@ def _get_hf_layer_weight_info(self, layer_id: int):
            )
        )

-        if self.model_config.attn_config.use_mla and self.model_config.mla_ops_type != MlaOpsType.MHA:
+        if (
+            self.model_config.attn_config.use_mla
+            and self.model_config.mla_ops_type != MlaOpsType.MHA
+        ):
            mla_layer_weights.append(
                MlaAttnAtomicWeight(
                    W.mla_kc,
@@ -522,7 +526,7 @@ def _create_python_model(self) -> Optional[GptModelBase]:
        py_hw_kernel_config = self.hw_kernel_config
        moe_config = self.moe_config
        max_generate_batch_size = self.max_generate_batch_size
-
+
        # Use GenericMoeModel with new config architecture
        # attention_type is determined from model_config.attn_config.use_mla
        self.py_model = GenericMoeModel(
@@ -546,11 +550,13 @@ def _from_hf(config: ModelConfig, ckpt_path: str):
        config_json = json.loads(content)
        config.inter_size = config_json["intermediate_size"]
        config.attn_config.head_num = config_json["num_attention_heads"]
-        config.attn_config.kv_head_num = config_json.get("num_key_value_heads", config.attn_config.head_num)
+        config.attn_config.kv_head_num = config_json.get(
+            "num_key_value_heads", config.attn_config.head_num
+        )
        config.num_layers = config_json["num_hidden_layers"]
-        config.attn_config.rope_config.base = int(config_json.get(
-            "rope_theta", config.attn_config.rope_config.base
-        ))
+        config.attn_config.rope_config.base = int(
+            config_json.get("rope_theta", config.attn_config.rope_config.base)
+        )
        config.vocab_size = config_json["vocab_size"]
        config.layernorm_eps = config_json.get("rms_norm_eps", 1e-06)
        config.tie_word_embeddings = config_json.get("tie_word_embeddings", False)
@@ -559,13 +565,19 @@ def _from_hf(config: ModelConfig, ckpt_path: str):
        # MLA config
        config.attn_config.use_mla = True
        q_lora_rank = config_json.get("q_lora_rank")
-        config.attn_config.q_lora_rank = int(q_lora_rank) if q_lora_rank is not None else 0
+        config.attn_config.q_lora_rank = (
+            int(q_lora_rank) if q_lora_rank is not None else 0
+        )
        kv_lora_rank = config_json.get("kv_lora_rank")
-        config.attn_config.kv_lora_rank = int(kv_lora_rank) if kv_lora_rank is not None else 0
+        config.attn_config.kv_lora_rank = (
+            int(kv_lora_rank) if kv_lora_rank is not None else 0
+        )
        config.attn_config.nope_head_dim = config_json["qk_nope_head_dim"]
        config.attn_config.rope_head_dim = config_json["qk_rope_head_dim"]
        config.attn_config.v_head_dim = config_json["v_head_dim"]
-        config.attn_config.size_per_head = config.attn_config.nope_head_dim + config.attn_config.rope_head_dim
+        config.attn_config.size_per_head = (
+            config.attn_config.nope_head_dim + config.attn_config.rope_head_dim
+        )
        config.attn_config.rope_config.dim = config.attn_config.rope_head_dim

        # yarn rotary config
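For reference, the MLA fields parsed above boil down to a little arithmetic on the HuggingFace config.json: the lora ranks default to 0 when absent, and the per-head size is the sum of the nope and rope head dims. A minimal standalone sketch (plain Python; the mla_dims helper and the sample numbers are illustrative assumptions, not part of this commit):

def mla_dims(config_json: dict) -> dict:
    # Mirrors the parsing above: lora ranks default to 0 when the key is
    # missing or null; size_per_head = qk_nope_head_dim + qk_rope_head_dim.
    q_lora_rank = config_json.get("q_lora_rank")
    kv_lora_rank = config_json.get("kv_lora_rank")
    return {
        "q_lora_rank": int(q_lora_rank) if q_lora_rank is not None else 0,
        "kv_lora_rank": int(kv_lora_rank) if kv_lora_rank is not None else 0,
        "size_per_head": (
            config_json["qk_nope_head_dim"] + config_json["qk_rope_head_dim"]
        ),
        "v_head_dim": config_json["v_head_dim"],
    }

# Illustrative numbers in the shape of public DeepSeek-V2/V3 configs:
# qk_nope_head_dim=128, qk_rope_head_dim=64 -> size_per_head = 192
print(mla_dims({"q_lora_rank": 1536, "kv_lora_rank": 512,
                "qk_nope_head_dim": 128, "qk_rope_head_dim": 64,
                "v_head_dim": 128}))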
@@ -575,8 +587,12 @@ def _from_hf(config: ModelConfig, ckpt_path: str):
            config.attn_config.rope_config.style = 5
            rope_scaling = config_json.get("rope_scaling")
            config.attn_config.rope_config.scale = rope_scaling["factor"]
-            config.attn_config.rope_config.factor1 = float(rope_scaling.get("beta_slow", 1))
-            config.attn_config.rope_config.factor2 = float(rope_scaling.get("beta_fast", 32))
+            config.attn_config.rope_config.factor1 = float(
+                rope_scaling.get("beta_slow", 1)
+            )
+            config.attn_config.rope_config.factor2 = float(
+                rope_scaling.get("beta_fast", 32)
+            )
            config.attn_config.rope_config.max_pos = rope_scaling[
                "original_max_position_embeddings"
            ]
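The yarn branch above copies the checkpoint's rope_scaling block onto the rope config: style 5 marks yarn, scale comes from factor, factor1/factor2 from beta_slow/beta_fast (defaulting to 1 and 32), and max_pos from original_max_position_embeddings. A standalone sketch of that mapping; the RopeConfig dataclass here is an illustrative stand-in, not the rtp_llm class:

from dataclasses import dataclass

@dataclass
class RopeConfig:  # stand-in for illustration only, not rtp_llm's rope config
    style: int = 0
    scale: float = 1.0
    factor1: float = 1.0
    factor2: float = 32.0
    max_pos: int = 0

def apply_yarn(rope_config: RopeConfig, rope_scaling: dict) -> RopeConfig:
    # Same field mapping and .get() fallbacks as the hunk above.
    rope_config.style = 5  # yarn
    rope_config.scale = rope_scaling["factor"]
    rope_config.factor1 = float(rope_scaling.get("beta_slow", 1))
    rope_config.factor2 = float(rope_scaling.get("beta_fast", 32))
    rope_config.max_pos = rope_scaling["original_max_position_embeddings"]
    return rope_config

# Example with illustrative numbers; missing beta_* keys fall back to 1 and 32.
print(apply_yarn(RopeConfig(), {"factor": 40,
                                "original_max_position_embeddings": 4096}))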
@@ -636,8 +652,25 @@ def get_weight_cls():

class DeepSeekV3MtpWeight(DeepSeekV2Weight):

-    def __init__(self, model_config: ModelConfig, parallelism_config, hw_kernel_config, kv_cache_config, merge_lora: bool = False, vit_config=None, **kwargs):
-        super().__init__(model_config=model_config, parallelism_config=parallelism_config, hw_kernel_config=hw_kernel_config, kv_cache_config=kv_cache_config, merge_lora=merge_lora, vit_config=vit_config, **kwargs)
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallelism_config,
+        hw_kernel_config,
+        kv_cache_config,
+        merge_lora: bool = False,
+        vit_config=None,
+        **kwargs,
+    ):
+        super().__init__(
+            model_config=model_config,
+            parallelism_config=parallelism_config,
+            hw_kernel_config=hw_kernel_config,
+            kv_cache_config=kv_cache_config,
+            merge_lora=merge_lora,
+            vit_config=vit_config,
+            **kwargs,
+        )

    def _get_weight_info(self):
        layer_weights: List[List[WeightModule]] = []