diff --git a/paddlenlp/transformers/gpt/modeling_auto.py b/paddlenlp/transformers/gpt/modeling_auto.py
index e21067ba42c3..4c2ac39b3597 100644
--- a/paddlenlp/transformers/gpt/modeling_auto.py
+++ b/paddlenlp/transformers/gpt/modeling_auto.py
@@ -658,7 +658,7 @@ def __init__(
             config.hidden_size,
         )
         self.word_embeddings.weight = dist.shard_tensor(
-            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Replicate()]
+            self.word_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
         )
         self.position_embeddings.weight = dist.shard_tensor(
             self.position_embeddings.weight, get_mesh(), [dist.Replicate(), dist.Shard(1)]
@@ -699,6 +699,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         # The 'with' block ensures the correct seed context is used
         with seed_guard_context(current_seed):
             embeddings = self.dropout(embeddings)
+        embeddings = dist.reshard(embeddings, get_mesh(), [dist.Replicate(), dist.Replicate()])
         return embeddings

@@ -1176,7 +1177,7 @@ def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
             shape=[config.vocab_size, config.hidden_size],
             dtype=paddle.get_default_dtype(),
         )
-        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+        self.weight = dist.shard_tensor(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])

     def forward(self, hidden_states, tensor_parallel_output=None):
@@ -1187,7 +1188,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        y = dist.reshard(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(0)])
+        y = dist.reshard(self.weight, get_mesh(self.ipp), [dist.Replicate(), dist.Shard(1)])
         logits = paddle.matmul(hidden_states, y, transpose_y=self.transpose_y)
         return logits
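
Context for reviewers: the diff switches the word-embedding and LM-head weights from vocab-dimension sharding (`dist.Shard(0)`) or full replication to hidden-dimension sharding (`dist.Shard(1)`) on the tensor-parallel mesh axis, and reshards the embedding output back to replicated before it leaves the layer. The sketch below is not part of the PR; the mesh layout and sizes are invented for illustration, and it assumes a launch like `python -m paddle.distributed.launch --devices 0,1,2,3`. It only shows what these placements mean in Paddle's auto-parallel API.

```python
import paddle
import paddle.distributed as dist

# Hypothetical sizes; in the PR these come from GPTConfig.
vocab_size, hidden_size = 50304, 1024

# 2-D mesh sketch: axis 0 replicates, axis 1 is the tensor-parallel ("mp")
# axis that the second placement entry refers to. In the PR the mesh is
# produced by get_mesh().
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

weight = paddle.randn([vocab_size, hidden_size])

# Old placement: Shard(0) split rows, i.e. the vocab dimension.
# New placement: Shard(1) splits columns, i.e. the hidden dimension, so each
# "mp" rank holds a [vocab_size, hidden_size // 2] slice of the weight.
weight = dist.shard_tensor(weight, mesh, [dist.Replicate(), dist.Shard(1)])

# With the weight sharded along hidden, the embedding output is sharded on
# its last axis as well; the added dist.reshard(...) call in the diff brings
# the activations back to fully replicated before returning them.
embeddings = paddle.randn([8, 128, hidden_size])
embeddings = dist.shard_tensor(embeddings, mesh, [dist.Replicate(), dist.Shard(2)])
embeddings = dist.reshard(embeddings, mesh, [dist.Replicate(), dist.Replicate()])
print(embeddings.shape)  # logical (global) shape is unchanged: [8, 128, 1024]
```

One design note: sharding along the hidden dimension keeps every rank's weight slice aligned with the matmul in `GPTLMHeadAuto.forward` (`transpose_y=True`), while the reshard after dropout pays a one-time communication cost so downstream layers still see replicated activations.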