Commit f4db5e6

[Bugfix][Model] Fix inference for Hunyuan dense models (vllm-project#25354)
Signed-off-by: anion <[email protected]>
Signed-off-by: Anion <[email protected]>
1 parent 099aaee commit f4db5e6
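
Summary: before this change, both HunYuanDenseV1ForCausalLM and HunYuanMoEV1ForCausalLM derived from a single HunYuanV1Base that mixed in MixtureOfExperts, so dense (non-MoE) checkpoints went through MoE-only initialization and broke at inference time. The commit splits the hierarchy: the shared forward/compute_logits/load_weights/embedding logic moves into HunyuanV1ModelBase, and only the MoE branch keeps the MixtureOfExperts interface. A minimal sketch of the resulting hierarchy follows; the stub classes here stand in for vLLM's real nn.Module, SupportsLoRA, SupportsPP, and MixtureOfExperts interfaces, and only the class split itself is taken from the diff:

# Simplified sketch of the class split this commit introduces.
# Stubs replace vLLM's real interfaces; method bodies are elided.

class SupportsLoRA: ...
class SupportsPP: ...
class MixtureOfExperts: ...


class HunyuanV1ModelBase(SupportsLoRA, SupportsPP):
    """Shared logic: forward, compute_logits, load_weights, embeddings."""

    def __init__(self, config):
        self.config = config


class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
    """MoE-only setup now lives here, not in the shared base."""

    def __init__(self, config):
        super().__init__(config)
        self.expert_weights = []    # MoE hyperparameters are set only
        self.num_expert_groups = 1  # when loading an MoE checkpoint


class HunYuanDenseV1Base(HunyuanV1ModelBase):
    """Dense models skip MixtureOfExperts entirely."""


class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
    pass


class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
    pass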

File tree

1 file changed: +59 -47 lines changed
vllm/model_executor/models/hunyuan_v1.py

Lines changed: 59 additions & 47 deletions
@@ -888,7 +888,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loaded_params


-class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
+class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -930,6 +930,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.lm_head = PPMissingLayer()

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+
+class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
         # Set MoE hyperparameters
         self.expert_weights = []
         self.num_expert_groups = 1
@@ -988,57 +1038,19 @@ def update_physical_experts_metadata(
             moe.n_redundant_experts = self.num_redundant_experts
             moe.experts.update_expert_map()

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()

-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
-        return model_output

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states)
-        return logits
+class HunYuanDenseV1Base(HunyuanV1ModelBase):

-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights)
-
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)


-class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
+class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
     pass


-class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
-    pass
+class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
+    pass
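
For reference, a minimal way to exercise the fixed path is to load a dense Hunyuan checkpoint through vLLM's offline LLM API. The model id below is illustrative, not taken from this commit; substitute the dense (non-MoE) Hunyuan checkpoint you actually use:

from vllm import LLM, SamplingParams

# Illustrative model id -- replace with your dense Hunyuan checkpoint.
llm = LLM(model="tencent/Hunyuan-7B-Instruct", trust_remote_code=True)

# A dense checkpoint now routes through HunYuanDenseV1ForCausalLM
# without touching the MoE-only setup.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
for out in outputs:
    print(out.outputs[0].text)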
