Skip to content

Commit 46d2744

Browse files
hrz394943230 and huruize
authored
[grpo] Fix bug when repeatedly calling inputs_to_rolloutrequest (#4823)
* [grpo] Fix nested data_dict caused by repeatedly calling inputs_to_rolloutrequest during multi-turn training. * Get lint to pass. --------- Co-authored-by: huruize <[email protected]>
1 parent bdbbc71 commit 46d2744

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,18 +1592,40 @@ def is_async_generate_eval_rollout_done(self):
15921592
def is_async_generate_train_rollout_done(self):
15931593
return not self.train_queue.empty()
15941594

1595-
def inputs_to_rolloutrequest(self, inputs: InputsType) -> RolloutInferRequest:
1595+
def inputs_to_rolloutrequest(self, inputs: InputsType) -> List[RolloutInferRequest]:
1596+
"""Convert a list of inputs to a list of RolloutInferRequest objects
15961597
1598+
If the input contains a 'data_dict' key, it will be used as the base for the new data_dict.
1599+
For other keys, if they overlap with keys in data_dict, the values from data_dict will be used.
1600+
Non-overlapping keys will be added to data_dict.
1601+
1602+
Args:
1603+
inputs: List of input dictionaries
1604+
1605+
Returns:
1606+
List of RolloutInferRequest objects
1607+
"""
15971608
request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
1598-
infer_requests = [
1599-
RolloutInferRequest(
1600-
**{
1601-
**{k: request[k]
1602-
for k in request_keys if k in request}, 'data_dict':
1603-
{k: request[k]
1604-
for k in request if k not in request_keys}
1605-
}) for request in inputs
1606-
]
1609+
infer_requests = []
1610+
1611+
for request in inputs:
1612+
# Get the base data_dict if it exists in the input
1613+
base_data_dict = {}
1614+
if 'data_dict' in request:
1615+
if isinstance(request['data_dict'], dict):
1616+
base_data_dict = request['data_dict']
1617+
else:
1618+
raise ValueError('data_dict exists but is not a dictionary')
1619+
1620+
# Collect all non-request_keys items as extra fields
1621+
extra_data = {k: request[k] for k in request if k not in request_keys and k != 'data_dict'}
1622+
1623+
# Merge the data_dict, keeping keys from base_data_dict as priority
1624+
final_data_dict = {**extra_data, **base_data_dict}
1625+
1626+
# Create RolloutInferRequest instance
1627+
req_args = {k: request[k] for k in request_keys if k in request}
1628+
infer_requests.append(RolloutInferRequest(**req_args, data_dict=final_data_dict))
16071629

16081630
return infer_requests
16091631

0 commit comments

Comments
 (0)