Skip to content

Commit 13253e9

Browse files
hjh0119Jintao-Huang
authored andcommitted
[bugfix] fix grpo mixed data training (#6269)
1 parent f27bf3a commit 13253e9

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -883,8 +883,6 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType:
883883
# NOTE: every key you register must appear in ALL rollout outputs
884884
# to avoid potential communication / synchronization issues
885885
metrics_for_logs_to_gather = {}
886-
if all('images' in data and data['images'] is not None for data in inputs):
887-
metrics_for_logs_to_gather['image'] = [inp['images'] for inp in inputs]
888886

889887
if all('solution' in inp for inp in inputs):
890888
metrics_for_logs_to_gather['solution'] = [inp['solution'] for inp in inputs]
@@ -2519,13 +2517,16 @@ def _process_image_data(image_data: Union[dict, str]) -> str:
25192517

25202518
for data in inputs:
25212519
# Extract required metadata fields
2522-
request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data}
2520+
request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data and data[key] is not None}
25232521
if 'uuid' not in request_data:
25242522
request_data['uuid'] = data['request_id'] # Use unique request_id for vLLM
25252523
# Preserve additional fields for multi-turn async scenarios
25262524
if self.args.vllm_server_pass_dataset:
25272525
# data_dict is already concatenated inside async engine
2528-
extra_fields = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS}
2526+
extra_fields = {
2527+
k: v
2528+
for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and data[k] is not None
2529+
}
25292530
if extra_fields:
25302531
request_data['data_dict'] = extra_fields
25312532
elif self.multi_turn_scheduler:
@@ -2537,7 +2538,11 @@ def _process_image_data(image_data: Union[dict, str]) -> str:
25372538
else:
25382539
raise ValueError('data_dict exists but is not a dictionary')
25392540
# Add fields that are not in metadata fields and not 'data_dict'
2540-
extra_data = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and k != 'data_dict'}
2541+
extra_data = {
2542+
k: v
2543+
for k, v in data.items()
2544+
if k not in REQUEST_METADATA_FIELDS and k != 'data_dict' and data[k] is not None
2545+
}
25412546
# Merge additional fields and existing data_dict
25422547
final_data_dict = {**extra_data, **base_data_dict}
25432548
request_data['data_dict'] = final_data_dict if final_data_dict else {}

0 commit comments

Comments
 (0)