@@ -1592,18 +1592,40 @@ def is_async_generate_eval_rollout_done(self):
def is_async_generate_train_rollout_done(self):
    """Report whether the train queue currently holds at least one item.

    Returns:
        The negation of ``self.train_queue.empty()``: a true value when the
        queue is not empty, a false value when it is.
    """
    queue_is_empty = self.train_queue.empty()
    return not queue_is_empty
15941594
def inputs_to_rolloutrequest(self, inputs: InputsType) -> List[RolloutInferRequest]:
    """Build one RolloutInferRequest per input dictionary.

    The keys 'messages', 'images', 'audios', 'videos', 'tools' and 'objects'
    are forwarded as direct request arguments when present. Every other key
    (apart from 'data_dict' itself) is gathered into the request's
    ``data_dict``. When an input already carries a 'data_dict' entry, its
    items are merged in last, so they win over same-named extra keys.

    Args:
        inputs: List of input dictionaries.

    Returns:
        List of RolloutInferRequest objects, in the same order as ``inputs``.

    Raises:
        ValueError: If an input has a 'data_dict' entry that is not a dict.
    """
    request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']

    def build_request(item):
        # Start from the caller-supplied data_dict, if any; it must be a dict.
        nested = item['data_dict'] if 'data_dict' in item else {}
        if not isinstance(nested, dict):
            raise ValueError('data_dict exists but is not a dictionary')

        # Anything that is neither a direct request argument nor the nested
        # data_dict itself is carried along as an extra field.
        extras = {}
        for key in item:
            if key == 'data_dict' or key in request_keys:
                continue
            extras[key] = item[key]

        # Unpack the nested dict last so its values take priority on overlap;
        # this also builds a new dict instead of mutating the caller's one.
        merged = {**extras, **nested}

        direct_args = {name: item[name] for name in request_keys if name in item}
        return RolloutInferRequest(**direct_args, data_dict=merged)

    return [build_request(item) for item in inputs]
16091631
0 commit comments