Skip to content

Commit 01d2c2a

Browse files
authored
[BugFix] fix reload gpt (#1945)
1 parent b2e98d8 commit 01d2c2a

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

paddlenlp/transformers/gpt/modeling.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -401,18 +401,14 @@ def __init__(self,
401401
self.word_embeddings = nn.Embedding(
402402
vocab_size,
403403
hidden_size,
404-
weight_attr=paddle.ParamAttr(
405-
name="word_embeddings",
406-
initializer=nn.initializer.Normal(
407-
mean=0.0, std=initializer_range)))
404+
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
405+
mean=0.0, std=initializer_range)))
408406

409407
self.position_embeddings = nn.Embedding(
410408
max_position_embeddings,
411409
hidden_size,
412-
weight_attr=paddle.ParamAttr(
413-
name="pos_embeddings",
414-
initializer=nn.initializer.Normal(
415-
mean=0.0, std=initializer_range)))
410+
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Normal(
411+
mean=0.0, std=initializer_range)))
416412

417413
self.dropout = nn.Dropout(hidden_dropout_prob)
418414

0 commit comments

Comments
 (0)