Commit a7a7251

decoder_weight -> decoder.weight (#6117)

1 parent 0e58759
1 file changed: +2 -2 lines

paddlenlp/ops/fast_transformer/transformer/decoding.py

Lines changed: 2 additions & 2 deletions
@@ -3105,11 +3105,11 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
             params["pos_emb"].append((self.model.gpt.embeddings.position_embeddings, "weight"))
 
         # if model share word_embeddings weight
-        if id(self.model.gpt.embeddings.word_embeddings) == id(self.model.lm_head.decoder_weight):
+        if id(self.model.gpt.embeddings.word_embeddings) == id(self.model.lm_head.decoder.weight):
            params["linear_weight"].append((self.model.gpt.embeddings.word_embeddings, "weight"))
        else:
            params["linear_weight"].append(
-                (self.model.lm_head.decoder_weight, False, partial(setattr, self, "decoder_weight"))
+                (self.model.lm_head.decoder.weight, False, partial(setattr, self, "weight"))
            )
 
        for k, v in params.items():
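
Context for the change: the commit message indicates the tied LM-head weight is now reached through lm_head.decoder.weight rather than a flat lm_head.decoder_weight attribute. The branch in the diff distinguishes shared from separate weights via Python object identity: id() of two names is equal only when they refer to the same underlying object. A minimal, standalone Paddle sketch of that identity check (hypothetical tensors, not code from this repository):

import paddle.nn as nn

emb = nn.Embedding(100, 16)

# Shared case: a second name bound to the *same* parameter object,
# analogous to a tied LM head -> the id() comparison is True,
# so the weight only needs to be registered once.
tied_weight = emb.weight
print(id(emb.weight) == id(tied_weight))    # True

# Separate case: a copy holds identical values but is a different object,
# so the else-branch would register it on its own.
untied_weight = emb.weight.clone()
print(id(emb.weight) == id(untied_weight))  # False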
