
Commit db1f991

refine gpt (#3447)

1 parent e544a04

2 files changed, +3 -1 lines changed


paddlenlp/transformers/gpt/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -1182,7 +1182,7 @@ def prepare_inputs_for_generation(self,
         # only last token for inputs_ids if cache is defined in kwargs
         position_ids = kwargs.get("position_ids", None)
         attention_mask = kwargs.get("attention_mask", None)
-        if attention_mask is not None and len(attention_mask.shape) == 4:
+        if attention_mask is not None and attention_mask.ndim == 4:
             attention_mask = attention_mask[:, -1:, -1:, :]
         if cache is not None:
             input_ids = input_ids[:, -1].unsqueeze(-1)
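
For reference, attention_mask.ndim and len(attention_mask.shape) are equivalent on Paddle tensors; the new form simply states the intent more directly. Below is a minimal sketch of what this branch does during cached decoding, using a hypothetical mask shape (the shapes are illustrative, not taken from the diff):

    import paddle

    # Hypothetical 4-D mask: [batch_size, num_heads, seq_len, seq_len].
    attention_mask = paddle.ones([2, 1, 8, 8])
    assert attention_mask.ndim == len(attention_mask.shape) == 4

    # With a cache, only the newest query position is scored, so the mask
    # is sliced down to [batch_size, 1, 1, seq_len] for the last token.
    last_step_mask = attention_mask[:, -1:, -1:, :]
    print(last_step_mask.shape)  # [2, 1, 1, 8]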

tests/transformers/gpt/test_modeling.py

Lines changed: 2 additions & 0 deletions
@@ -593,6 +593,7 @@ def test_lm_generate_gpt(self):
     def test_gpt_sample(self):
         tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
         model = GPTLMHeadModel.from_pretrained("gpt2-en")
+        model.eval()

         paddle.seed(128)
         np.random.seed(128)
@@ -631,6 +632,7 @@ def test_gpt_sample_max_time(self):
         # NOTE: duration changed sharply and can not be limit in a range for now.
         tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
         model = GPTLMHeadModel.from_pretrained("gpt2-en")
+        model.eval()

         paddle.seed(0)
         np.random.seed(0)
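
Paddle layers are constructed in train mode, where dropout is active, so sampling can differ across runs even under a fixed seed; calling model.eval() disables dropout and makes the seeded sample reproducible, which is what these tests rely on. A minimal sketch of the idea follows; the generate() call and its sampling arguments are assumed from PaddleNLP's generation API of this era, not taken from the diff:

    import numpy as np
    import paddle
    from paddlenlp.transformers import GPTLMHeadModel, GPTTokenizer

    tokenizer = GPTTokenizer.from_pretrained("gpt2-en")
    model = GPTLMHeadModel.from_pretrained("gpt2-en")
    model.eval()  # turn dropout off so seeded sampling is repeatable

    ids = paddle.to_tensor([tokenizer("Nice to meet")["input_ids"]])

    paddle.seed(128)
    np.random.seed(128)
    out1, _ = model.generate(ids, decode_strategy="sampling", top_k=5)  # assumed API

    paddle.seed(128)
    np.random.seed(128)
    out2, _ = model.generate(ids, decode_strategy="sampling", top_k=5)

    assert bool((out1 == out2).all())  # identical once dropout is disabled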
