@@ -888,7 +888,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
888
888
return loaded_params
889
889
890
890
891
- class HunYuanV1Base (nn .Module , SupportsLoRA , SupportsPP , MixtureOfExperts ):
891
+ class HunyuanV1ModelBase (nn .Module , SupportsLoRA , SupportsPP ):
892
892
packed_modules_mapping = {
893
893
"qkv_proj" : [
894
894
"q_proj" ,
@@ -930,6 +930,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
930
930
else :
931
931
self .lm_head = PPMissingLayer ()
932
932
933
+ def forward (
934
+ self ,
935
+ input_ids : torch .Tensor ,
936
+ positions : torch .Tensor ,
937
+ intermediate_tensors : Optional [IntermediateTensors ] = None ,
938
+ inputs_embeds : Optional [torch .Tensor ] = None ,
939
+ ) -> Union [torch .Tensor , IntermediateTensors ]:
940
+ model_output = self .model (input_ids , positions , intermediate_tensors ,
941
+ inputs_embeds )
942
+ return model_output
943
+
944
+ def compute_logits (
945
+ self ,
946
+ hidden_states : torch .Tensor ,
947
+ ) -> Optional [torch .Tensor ]:
948
+ logits = self .logits_processor (self .lm_head , hidden_states )
949
+ return logits
950
+
951
+ def make_empty_intermediate_tensors (
952
+ self , batch_size : int , dtype : torch .dtype ,
953
+ device : torch .device ) -> IntermediateTensors :
954
+ return IntermediateTensors ({
955
+ "hidden_states" :
956
+ torch .zeros ((batch_size , self .config .hidden_size ),
957
+ dtype = dtype ,
958
+ device = device ),
959
+ "residual" :
960
+ torch .zeros ((batch_size , self .config .hidden_size ),
961
+ dtype = dtype ,
962
+ device = device ),
963
+ })
964
+
965
+ def load_weights (self , weights : Iterable [tuple [str ,
966
+ torch .Tensor ]]) -> set [str ]:
967
+ loader = AutoWeightsLoader (
968
+ self ,
969
+ skip_prefixes = (["lm_head." ]
970
+ if self .config .tie_word_embeddings else None ),
971
+ )
972
+ return loader .load_weights (weights )
973
+
974
+ def get_input_embeddings (self , input_ids : torch .Tensor ) -> torch .Tensor :
975
+ return self .model .get_input_embeddings (input_ids )
976
+
977
+
978
+ class HunYuanMoEV1Base (HunyuanV1ModelBase , MixtureOfExperts ):
979
+
980
+ def __init__ (self , * , vllm_config : VllmConfig , prefix : str = "" ):
981
+ super ().__init__ (vllm_config = vllm_config , prefix = prefix )
982
+
933
983
# Set MoE hyperparameters
934
984
self .expert_weights = []
935
985
self .num_expert_groups = 1
@@ -988,57 +1038,19 @@ def update_physical_experts_metadata(
988
1038
moe .n_redundant_experts = self .num_redundant_experts
989
1039
moe .experts .update_expert_map ()
990
1040
991
- def get_input_embeddings (self , input_ids : torch . Tensor ) -> torch . Tensor :
992
- return self .model .get_input_embeddings ( input_ids )
1041
+ def get_expert_mapping (self ) -> list [ tuple [ str , str , int , str ]] :
1042
+ return self .model .get_expert_mapping ( )
993
1043
994
- def forward (
995
- self ,
996
- input_ids : torch .Tensor ,
997
- positions : torch .Tensor ,
998
- intermediate_tensors : Optional [IntermediateTensors ] = None ,
999
- inputs_embeds : Optional [torch .Tensor ] = None ,
1000
- ) -> Union [torch .Tensor , IntermediateTensors ]:
1001
- model_output = self .model (input_ids , positions , intermediate_tensors ,
1002
- inputs_embeds )
1003
- return model_output
1004
1044
1005
- def compute_logits (
1006
- self ,
1007
- hidden_states : torch .Tensor ,
1008
- ) -> Optional [torch .Tensor ]:
1009
- logits = self .logits_processor (self .lm_head , hidden_states )
1010
- return logits
1045
+ class HunYuanDenseV1Base (HunyuanV1ModelBase ):
1011
1046
1012
- def make_empty_intermediate_tensors (
1013
- self , batch_size : int , dtype : torch .dtype ,
1014
- device : torch .device ) -> IntermediateTensors :
1015
- return IntermediateTensors ({
1016
- "hidden_states" :
1017
- torch .zeros ((batch_size , self .config .hidden_size ),
1018
- dtype = dtype ,
1019
- device = device ),
1020
- "residual" :
1021
- torch .zeros ((batch_size , self .config .hidden_size ),
1022
- dtype = dtype ,
1023
- device = device ),
1024
- })
1025
-
1026
- def load_weights (self , weights : Iterable [tuple [str ,
1027
- torch .Tensor ]]) -> set [str ]:
1028
- loader = AutoWeightsLoader (
1029
- self ,
1030
- skip_prefixes = (["lm_head." ]
1031
- if self .config .tie_word_embeddings else None ),
1032
- )
1033
- return loader .load_weights (weights )
1034
-
1035
- def get_expert_mapping (self ) -> list [tuple [str , str , int , str ]]:
1036
- return self .model .get_expert_mapping ()
1047
+ def __init__ (self , * , vllm_config : VllmConfig , prefix : str = "" ):
1048
+ super ().__init__ (vllm_config = vllm_config , prefix = prefix )
1037
1049
1038
1050
1039
- class HunYuanDenseV1ForCausalLM (HunYuanV1Base ):
1051
+ class HunYuanDenseV1ForCausalLM (HunYuanDenseV1Base ):
1040
1052
pass
1041
1053
1042
1054
1043
- class HunYuanMoEV1ForCausalLM (HunYuanV1Base ):
1044
- pass
1055
+ class HunYuanMoEV1ForCausalLM (HunYuanMoEV1Base ):
1056
+ pass
0 commit comments