Skip to content

Commit 55db2ff

Browse files
authored
[Inference] Fix multibatch inference (PaddlePaddle#9831)
* fix batch infra * fix deepseekv2 infra
1 parent eab22f2 commit 55db2ff

File tree

4 files changed

+7
-4
lines changed

4 files changed

+7
-4
lines changed

llm/predict/predictor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,17 @@ def _preprocess(self, source):
221221
source = [source] if isinstance(source, str) else source
222222
source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]
223223

224+
return_position_ids = False
225+
return_attention_mask = False
226+
if len(source) > 1:
227+
return_position_ids = True
228+
return_attention_mask = True
224229
tokenized_source = self.tokenizer(
225230
source,
226231
max_length=self.config.src_length,
227232
truncation=True,
228-
return_position_ids=True if not isinstance(self.tokenizer, ChatGLMTokenizer) else False,
233+
return_position_ids=True if not isinstance(self.tokenizer, ChatGLMTokenizer) else return_position_ids,
234+
return_attention_mask=return_attention_mask,
229235
truncation_side="left",
230236
return_tensors=self.return_tensors,
231237
padding=True,

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,7 +1772,6 @@ def prepare_inputs_for_generation(
17721772
):
17731773
batch_size, seq_length = input_ids.shape
17741774
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))
1775-
attention_mask = kwargs.get("attention_mask", None)
17761775
if past_key_values:
17771776
input_ids = input_ids[:, -1].unsqueeze(axis=-1)
17781777
position_ids = position_ids[:, -1].unsqueeze(-1)

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1493,7 +1493,6 @@ def prepare_inputs_for_generation(
14931493
):
14941494
batch_size, seq_length = input_ids.shape
14951495
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))
1496-
attention_mask = kwargs.get("attention_mask", None)
14971496
if past_key_values:
14981497
input_ids = input_ids[:, -1].unsqueeze(axis=-1)
14991498
position_ids = position_ids[:, -1].unsqueeze(-1)

paddlenlp/transformers/qwen2_moe/modeling.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,6 @@ def prepare_inputs_for_generation(
14291429
):
14301430
batch_size, seq_length = input_ids.shape
14311431
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))
1432-
attention_mask = kwargs.get("attention_mask", None)
14331432
if past_key_values:
14341433
input_ids = input_ids[:, -1].unsqueeze(axis=-1)
14351434
position_ids = position_ids[:, -1].unsqueeze(-1)

0 commit comments

Comments (0)