diff --git a/paddlenlp/transformers/chatglm_v2/modeling.py b/paddlenlp/transformers/chatglm_v2/modeling.py
index d4b4a39758d9..bfe740616433 100644
--- a/paddlenlp/transformers/chatglm_v2/modeling.py
+++ b/paddlenlp/transformers/chatglm_v2/modeling.py
@@ -405,13 +405,16 @@ def forward(
             multiplier = self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition
-            key_layer = key_layer.unsqueeze(-2).tile([1, 1, 1, multiplier, 1])
+            S, B, G, D = key_layer.shape
+            key_layer = key_layer.unsqueeze(-2)
+            key_layer = key_layer.expand(S, B, G, multiplier, D)
             key_layer = key_layer.reshape(
-                key_layer.shape[:2] + [self.num_attention_heads_per_partition, self.hidden_size_per_attention_head]
+                S, B, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
             )
-            value_layer = value_layer.unsqueeze(-2).tile([1, 1, 1, multiplier, 1])
+            value_layer = value_layer.unsqueeze(-2)
+            value_layer = value_layer.expand(S, B, G, multiplier, D)
             value_layer = value_layer.reshape(
-                value_layer.shape[:2] + [self.num_attention_heads_per_partition, self.hidden_size_per_attention_head]
+                S, B, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
             )

         # ==================================
         # core attention computation
         # ==================================
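
For context, here is a minimal standalone sketch (not part of the patch) checking that the new `expand`-based broadcast produces the same tensor as the old `tile`-based copy. The concrete shapes, the `multiplier` value, and the `kv` name are illustrative assumptions, not values taken from the model:

```python
import paddle

# Illustrative shapes: seq_len, batch, KV groups, head dim (assumed, not from the PR).
S, B, G, D = 5, 2, 4, 8
multiplier = 3  # query heads per KV group (assumed)

kv = paddle.randn([S, B, G, D])

# Old path: tile eagerly writes out `multiplier` repeated copies
# of the inserted size-1 axis.
tiled = kv.unsqueeze(-2).tile([1, 1, 1, multiplier, 1])
tiled = tiled.reshape([S, B, G * multiplier, D])

# New path: expand broadcasts the inserted size-1 axis to `multiplier`,
# then reshape folds the group and repeat axes into a single head axis.
expanded = kv.unsqueeze(-2).expand([S, B, G, multiplier, D])
expanded = expanded.reshape([S, B, G * multiplier, D])

assert bool(paddle.equal_all(tiled, expanded))
```

Both paths take the `[S, B, G, D]` key/value tensor to `[S, B, G * multiplier, D]`: each of the `G` multi-query KV heads is repeated `multiplier` times so the head count matches `num_attention_heads_per_partition`, which is what the patched `forward` relies on.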