@@ -1592,18 +1592,40 @@ def is_async_generate_eval_rollout_done(self):
1592
1592
def is_async_generate_train_rollout_done (self ):
1593
1593
return not self .train_queue .empty ()
1594
1594
1595
- def inputs_to_rolloutrequest (self , inputs : InputsType ) -> RolloutInferRequest :
1595
+ def inputs_to_rolloutrequest (self , inputs : InputsType ) -> List [RolloutInferRequest ]:
1596
+ """Convert a list of inputs to a list of RolloutInferRequest objects
1596
1597
1598
+ If the input contains a 'data_dict' key, it will be used as the base for the new data_dict.
1599
+ For other keys, if they overlap with keys in data_dict, the values from data_dict will be used.
1600
+ Non-overlapping keys will be added to data_dict.
1601
+
1602
+ Args:
1603
+ inputs: List of input dictionaries
1604
+
1605
+ Returns:
1606
+ List of RolloutInferRequest objects
1607
+ """
1597
1608
request_keys = ['messages' , 'images' , 'audios' , 'videos' , 'tools' , 'objects' ]
1598
- infer_requests = [
1599
- RolloutInferRequest (
1600
- ** {
1601
- ** {k : request [k ]
1602
- for k in request_keys if k in request }, 'data_dict' :
1603
- {k : request [k ]
1604
- for k in request if k not in request_keys }
1605
- }) for request in inputs
1606
- ]
1609
+ infer_requests = []
1610
+
1611
+ for request in inputs :
1612
+ # Get the base data_dict if it exists in the input
1613
+ base_data_dict = {}
1614
+ if 'data_dict' in request :
1615
+ if isinstance (request ['data_dict' ], dict ):
1616
+ base_data_dict = request ['data_dict' ]
1617
+ else :
1618
+ raise ValueError ('data_dict exists but is not a dictionary' )
1619
+
1620
+ # Collect all non-request_keys items as extra fields
1621
+ extra_data = {k : request [k ] for k in request if k not in request_keys and k != 'data_dict' }
1622
+
1623
+ # Merge the data_dict, keeping keys from base_data_dict as priority
1624
+ final_data_dict = {** extra_data , ** base_data_dict }
1625
+
1626
+ # Create RolloutInferRequest instance
1627
+ req_args = {k : request [k ] for k in request_keys if k in request }
1628
+ infer_requests .append (RolloutInferRequest (** req_args , data_dict = final_data_dict ))
1607
1629
1608
1630
return infer_requests
1609
1631
0 commit comments