Skip to content

Commit cf497bb

Browse files
committed
fix batch infra
1 parent fb3e4c0 commit cf497bb

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

llm/predict/predictor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,17 @@ def _preprocess(self, source):
212212
source = [source] if isinstance(source, str) else source
213213
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]
214214

215+
return_position_ids = False
216+
return_attention_mask = False
217+
if len(source) > 1:
218+
return_position_ids = True
219+
return_attention_mask = True
215220
tokenized_source = self.tokenizer(
216221
source,
217222
max_length=self.config.src_length,
218223
truncation=True,
219-
return_position_ids=True if not isinstance(self.tokenizer, ChatGLMTokenizer) else False,
224+
return_position_ids=True if not isinstance(self.tokenizer, ChatGLMTokenizer) else return_position_ids,
225+
return_attention_mask=return_attention_mask,
220226
truncation_side="left",
221227
return_tensors=self.return_tensors,
222228
padding=True,

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1493,7 +1493,6 @@ def prepare_inputs_for_generation(
14931493
):
14941494
batch_size, seq_length = input_ids.shape
14951495
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))
1496-
attention_mask = kwargs.get("attention_mask", None)
14971496
if past_key_values:
14981497
input_ids = input_ids[:, -1].unsqueeze(axis=-1)
14991498
position_ids = position_ids[:, -1].unsqueeze(-1)

paddlenlp/transformers/qwen2_moe/modeling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,6 @@ def prepare_inputs_for_generation(
14291429
):
14301430
batch_size, seq_length = input_ids.shape
14311431
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))
1432-
attention_mask = kwargs.get("attention_mask", None)
14331432
if past_key_values:
14341433
input_ids = input_ids[:, -1].unsqueeze(axis=-1)
14351434
position_ids = position_ids[:, -1].unsqueeze(-1)

0 commit comments

Comments (0)