
Commit 4bcf19e (parent 86223cd)

fix qwen2 tp16

1 file changed: +2 -2 lines

lightllm/models/llama/layer_infer/transformer_layer_infer.py

2 additions & 2 deletions
@@ -39,8 +39,8 @@ def __init__(self, layer_num, network_config, mode=[]):
         super().__init__(layer_num, network_config, mode)
         self.eps_ = network_config["rms_norm_eps"]
         self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_
-        self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_
-        self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_
+        self.tp_k_head_num_ = max(network_config["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.tp_v_head_num_ = max(network_config["num_key_value_heads"] // self.tp_world_size_, 1)
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
         self.embed_dim_ = network_config["hidden_size"]
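
Why the clamp: Qwen2 uses grouped-query attention, so num_key_value_heads can be smaller than the tensor-parallel world size. At tp16 the integer division num_key_value_heads // self.tp_world_size_ floors to zero KV heads per rank, leaving those ranks with no K/V projection; max(..., 1) guarantees every rank keeps at least one KV head. A minimal sketch of the arithmetic, using illustrative config values (not taken from the commit):

# Sketch of the per-rank head-count math this commit fixes.
# Config values are illustrative (Qwen2-72B-style GQA), not from the commit.
network_config = {
    "num_attention_heads": 64,
    "num_key_value_heads": 8,  # grouped-query attention: far fewer KV heads
}
tp_world_size = 16

# Query heads split cleanly across ranks: 64 // 16 = 4
tp_q_head_num = network_config["num_attention_heads"] // tp_world_size

# Before the fix: 8 // 16 = 0 KV heads per rank
kv_before = network_config["num_key_value_heads"] // tp_world_size

# After the fix: clamp to at least one KV head per rank
# (KV heads are then effectively replicated across ranks rather than split)
kv_after = max(network_config["num_key_value_heads"] // tp_world_size, 1)

print(tp_q_head_num, kv_before, kv_after)  # -> 4 0 1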
