
Commit 3b5086c

hjh0119 authored and Jintao-Huang committed
[bugfix] fix grpo mllm multi turn (#5840)
1 parent e5d7c0f commit 3b5086c

File tree

4 files changed (+47, -27 lines)


docs/source/Instruction/GRPO/DeveloperGuide/多轮训练.md

Lines changed: 9 additions & 0 deletions

@@ -163,6 +163,15 @@ swift rollout \
 
 The full-trajectory data can be obtained via trajectory_inputs in kwargs; for a concrete implementation, see the [MultiTurnThinkingTips class](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)
 
+### Modifying Multimodal Data
+In multimodal multi-turn interaction scenarios, you may need to dynamically add, delete, or modify multimodal data during the conversation and ensure these changes are synchronized to the trainer.
+
+Implementation: use rollout_infos to override the original dataset's multimodal content under the specified keys.
+
+Currently supported override keys: images, audios, videos.
+
+For details, see the [DeepEyes Scheduler](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py#L403-L404)
+
 ### Returning response token ids
 In the default multi-turn workflow, the scheduler first returns the model-generated text string to the trainer, which re-encodes it into token ids for training. To avoid this redundant encoding, the scheduler can return response_token_ids directly, skipping the trainer-side re-encode.

docs/source_en/Instruction/GRPO/DeveloperGuide/multi_turn.md

Lines changed: 9 additions & 0 deletions

@@ -172,6 +172,15 @@ The complete trajectory can be accessed via `trajectory_inputs` in `kwargs`.
 
 For a concrete implementation, see the [MultiTurnThinkingTips class](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py)
 
+### Multimodal Data Override
+In multimodal, multi-turn interactions, you may need to dynamically add, delete, or modify multimodal data during the conversation and ensure these changes are synchronized to the trainer.
+
+Implementation: Use `rollout_infos` to override the original multimodal content in the dataset by specifying the corresponding keys.
+
+Supported override keys: images, audios, videos.
+
+For details, see [DeepEyes Scheduler](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py#L403-L404).
+
 ### Returning response token IDs
 
 In the default workflow, the scheduler returns text and the trainer re-encodes it to token IDs for training.
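
Putting the two doc sections above together, a scheduler's `step` might look roughly like the sketch below. This is a minimal sketch, not the verified `MultiTurnScheduler` contract: the return-dict keys (`infer_request`, `rollout_infos`, `response_token_ids`), the `token_ids` attribute, and the `crop_region` helper are illustrative assumptions; see the linked plugin files for the authoritative interface.

```python
# Minimal sketch of a multi-turn scheduler step that overrides images via
# rollout_infos and hands back token ids directly. Key names and crop_region
# are assumptions, not verified ms-swift API.
class ImageCropScheduler:

    def step(self, infer_request, response_choice, current_turn):
        # Ask a follow-up question about a cropped view of the last image.
        infer_request.messages.append({'role': 'user', 'content': 'Zoom in and re-check.'})
        infer_request.images.append(crop_region(infer_request.images[-1]))  # hypothetical helper

        return {
            'infer_request': infer_request,
            # Any of 'images' / 'audios' / 'videos' placed here replaces the
            # sample's original dataset values on the trainer side.
            'rollout_infos': {'images': infer_request.images},
            # Optional: return token ids to skip trainer-side re-encoding
            # (assumed attribute name).
            'response_token_ids': response_choice.token_ids,
        }
```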

examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py

Lines changed: 1 addition & 0 deletions

@@ -400,6 +400,7 @@ def step(self, infer_request, response_choice, current_turn):
         infer_request.messages.append({'role': 'user', 'content': query})
         if cropped_img:
             infer_request.images.append(cropped_img)
+        # override the images
         extra_info['images'] = infer_request.images
 
         # Return dictionary format according to new MultiTurnScheduler interface
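
Here `extra_info` is the per-sample metadata dict the plugin hands back to the trainer; judging from the docs above and the trainer hunk below, it surfaces as `rollout_infos`, so writing the accumulated image list into `extra_info['images']` is what makes the cropped images replace the sample's original images during training.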

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 28 additions & 27 deletions
@@ -1542,12 +1542,7 @@ def _compute_loss_chunked(self, model, inputs: DataType):
             end_idx = min(start_idx + new_chunk_size, batch_size)
 
             if start_idx < batch_size:
-                # Create chunk inputs
-                for key, value in inputs.items():
-                    if isinstance(value, torch.Tensor):
-                        chunk_inputs[key] = value[start_idx:end_idx]
-                    else:
-                        chunk_inputs[key] = value
+                chunk_inputs = self.get_chunked_inputs(inputs, start_idx, end_idx)
 
                 # Compute loss and metrics for this chunk (without updating global metrics)
                 chunk_loss, chunk_metrics_data = self._compute_loss_and_metrics(model, chunk_inputs)
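
With this change, the loss path no longer inlines its own tensor-slicing loop: it delegates to the shared `get_chunked_inputs` helper (added as a method in the final hunk), so loss computation and the logps/entropies path below chunk their inputs the same way, including the multimodal re-encode.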
@@ -1816,26 +1811,6 @@ def _get_per_token_logps_and_entropies_chunked(self,
             ``False``.
         """
 
-        def get_chunked_inputs(inputs, start_idx, end_idx):
-            chunk_inputs = {}
-            if not self.is_multimodal:
-                # for LLM, slice the inputs
-                for key, val in inputs.items():
-                    if isinstance(val, torch.Tensor):
-                        chunk_inputs[key] = val[start_idx:end_idx]
-                    else:
-                        chunk_inputs[key] = val
-            else:
-                # for MLLM, re-encode to get mm-related inputs
-                origin_data = inputs['_origin_data'][start_idx:end_idx]
-                template = self.template
-                with self._template_context(template):
-                    chunk_inputs = [template.encode(data) for data in origin_data]
-                    chunk_inputs = to_device(template.data_collator(chunk_inputs), self.model.device)
-                chunk_inputs['logits_to_keep'] = inputs['logits_to_keep']
-                chunk_inputs.pop('labels', None)
-            return chunk_inputs
-
         batch_size = inputs['input_ids'].shape[0]
         mode = 'train' if self.model.training else 'eval'
         chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
@@ -1855,7 +1830,7 @@ def get_chunked_inputs(inputs, start_idx, end_idx):
             end_idx = min(start_idx + new_chunk_size, batch_size)
 
             if start_idx < end_idx:
-                chunk_inputs = get_chunked_inputs(inputs, start_idx, end_idx)
+                chunk_inputs = self.get_chunked_inputs(inputs, start_idx, end_idx)
 
                 chunk_logps, chunk_entropies = self._get_per_token_logps_and_entropies_single(
                     model, chunk_inputs, compute_entropy)
@@ -2542,6 +2517,14 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out
             input_data['finish_reason'] = choice.finish_reason
             input_data['is_truncated'] = choice.finish_reason == 'length'
 
+            # Step 5: override multi-modal data from rollout_infos
+            if output.rollout_infos:
+                multi_modal_keys = ['images', 'videos', 'audios']
+                for key in multi_modal_keys:
+                    if key in output.rollout_infos:
+                        input_data[key] = output.rollout_infos[key]
+                        logger.info_once(f'Overriding multi-modal data from rollout_infos for key: {key}')
+
             return input_data
 
         if not self.dynamic_num_samples:
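
Design note: the trainer copies only a fixed allowlist of keys (`images`, `videos`, `audios`) from `rollout_infos` rather than merging it wholesale, so arbitrary scheduler metadata cannot shadow dataset fields; and `logger.info_once`, as the name suggests, should emit the message once per key rather than once per sample.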
@@ -2840,3 +2823,21 @@ def _get_last_indices(self, request_ids: List[str]) -> torch.Tensor:
         for i, rid in enumerate(request_ids):
             seen[rid] = i
         return torch.tensor(list(seen.values()), dtype=torch.long, device=self.accelerator.device)
+
+    def get_chunked_inputs(self, inputs, start_idx, end_idx):
+        chunk_inputs = {}
+        # for LLM, slice the inputs
+        for key, val in inputs.items():
+            if isinstance(val, torch.Tensor):
+                chunk_inputs[key] = val[start_idx:end_idx]
+            else:
+                chunk_inputs[key] = val
+        if self.is_multimodal:
+            # for MLLM, re-encode to get mm-related inputs
+            origin_data = inputs['_origin_data'][start_idx:end_idx]
+            template = self.template
+            with self._template_context(template):
+                encoded_data = [template.encode(data) for data in origin_data]
+                chunk_inputs.update(to_device(template.data_collator(encoded_data), self.model.device))
+            chunk_inputs.pop('labels', None)
+        return chunk_inputs
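
To make the LLM-branch chunking concrete, here is a standalone toy (assumed shapes, independent of the trainer class) mirroring the slicing rule above: tensors are sliced along the batch dimension, and non-tensor entries such as `logits_to_keep` pass through unchanged.

```python
import torch

# Toy inputs: a batch of 4 sequences plus one non-tensor entry.
inputs = {
    'input_ids': torch.arange(12).reshape(4, 3),
    'logits_to_keep': 3,  # passed through, not sliced
}

def slice_chunk(inputs, start_idx, end_idx):
    # Same rule as the LLM branch of get_chunked_inputs.
    return {
        k: v[start_idx:end_idx] if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }

chunk = slice_chunk(inputs, 0, 2)
print(chunk['input_ids'].shape)  # torch.Size([2, 3])
print(chunk['logits_to_keep'])   # 3
```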
