Skip to content

Commit 43dab1a

Browse files
authored
fix multimodal padding_free prediction_step (#4839)
1 parent 88cbad1 commit 43dab1a

File tree

9 files changed

+28
-10
lines changed

9 files changed

+28
-10
lines changed

examples/deploy/agent/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def infer_stream(client, model: str, messages, tools):
5555
response = ''
5656
print(f'query: {query}\nresponse: ', end='')
5757
for chunk in gen:
58+
if chunk is None:
59+
continue
5860
delta = chunk.choices[0].delta.content
5961
response += delta
6062
print(delta, end='', flush=True)
@@ -68,6 +70,8 @@ def infer_stream(client, model: str, messages, tools):
6870
model=model, messages=messages, tools=tools, max_tokens=512, temperature=0, stream=True)
6971
print(f'query: {query}\nresponse2: ', end='')
7072
for chunk in gen:
73+
if chunk is None:
74+
continue
7175
print(chunk.choices[0].delta.content, end='', flush=True)
7276
print()
7377

examples/deploy/client/llm/chat/openai_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def infer_stream(client, model: str, messages):
2020
gen = client.chat.completions.create(model=model, messages=messages, stream=True, temperature=0)
2121
print(f'messages: {messages}\nresponse: ', end='')
2222
for chunk in gen:
23+
if chunk is None:
24+
continue
2325
print(chunk.choices[0].delta.content, end='', flush=True)
2426
print()
2527

examples/deploy/client/mllm/openai_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def infer_stream(client, model: str, messages):
2121
gen = client.chat.completions.create(model=model, messages=messages, stream=True, temperature=0)
2222
print(f'messages: {messages}\nresponse: ', end='')
2323
for chunk in gen:
24+
if chunk is None:
25+
continue
2426
print(chunk.choices[0].delta.content, end='', flush=True)
2527
print()
2628

swift/llm/template/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_id
309309
added_tokens_len += token_len - 1
310310
return input_ids, labels
311311

312-
def training_step_context(self, model, inputs):
312+
def forward_context(self, model, inputs):
313313
return nullcontext()
314314

315315
@staticmethod

swift/llm/template/template/internvl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
5656
encoded['pixel_values'] = pixel_values
5757
return encoded
5858

59-
def training_step_context(self, model, inputs):
59+
def forward_context(self, model, inputs):
6060
model_name = model.language_model.__class__.__name__.lower()
6161
if self._packing and 'internlm2' in model_name:
6262
position_ids = inputs['position_ids']
6363
modeling_module = model.language_model.model.layers[0].attention.__class__
6464
return self._patch_flash_attention_forward(modeling_module, position_ids, use_new_func=True)
6565
else:
66-
return super().training_step_context(model, inputs)
66+
return super().forward_context(model, inputs)
6767

6868
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
6969
embedding = model.get_input_embeddings()

swift/llm/template/template/qwen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ def _get_new_tokens(i):
297297
encoded['labels'] = labels
298298
return encoded
299299

300-
def training_step_context(self, model, inputs):
300+
def forward_context(self, model, inputs):
301301
if 'real_position_ids' not in inputs:
302-
return super().training_step_context(model, inputs)
302+
return super().forward_context(model, inputs)
303303
if self.version == 'v2':
304304
from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
305305
elif self.version == 'v2_5':

swift/trainers/rlhf_trainer/dpo_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,10 @@ def get_per_token_logps(
125125

126126
def training_step(self, model, inputs, *args, **kwargs):
127127
inputs['_position_ids'] = inputs.get('position_ids')
128-
with self.template.training_step_context(self.model, inputs):
128+
with self.template.forward_context(self.model, inputs):
129129
return super().training_step(model, inputs, *args, **kwargs)
130+
131+
def prediction_step(self, model, inputs, *args, **kwargs):
132+
inputs['_position_ids'] = inputs.get('position_ids')
133+
with self.template.forward_context(self.model, inputs):
134+
return super().prediction_step(model, inputs, *args, **kwargs)

swift/trainers/rlhf_trainer/gkd_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ def training_step(self,
148148
inputs['attention_mask'] = new_attention_mask
149149
inputs['labels'] = new_labels
150150

151-
with self.template.training_step_context(self.model, inputs):
151+
with self.template.forward_context(self.model, inputs):
152152
loss = HFSFTTrainer.training_step(self, model, inputs, num_items_in_batch)
153153
return loss
154+
155+
def prediction_step(self, model, inputs, *args, **kwargs):
156+
with self.template.forward_context(self.model, inputs):
157+
return super().prediction_step(model, inputs, *args, **kwargs)

swift/trainers/trainers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,9 @@ def prediction_step(
237237
**gen_kwargs,
238238
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
239239
if not self.args.predict_with_generate or prediction_loss_only:
240-
return super().prediction_step(
241-
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
240+
with self.template.forward_context(self.model, inputs):
241+
return super().prediction_step(
242+
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
242243
from swift.llm import RequestConfig, InferRequest
243244
data_list = inputs['_data']
244245
labels_list = [InferRequest.remove_response(data['messages']) for data in data_list]
@@ -340,5 +341,5 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
340341
return (loss, outputs) if return_outputs else loss
341342

342343
def training_step(self, model, inputs, *args, **kwargs):
343-
with self.template.training_step_context(self.model, inputs):
344+
with self.template.forward_context(self.model, inputs):
344345
return super().training_step(model, inputs, *args, **kwargs)

0 commit comments

Comments
 (0)