@@ -1094,8 +1094,20 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
             InferRequest.remove_response(messages)
             template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
             res_context_list, _, _ = self.template._swift_encode(template_inputs)
-            prompts_text.append(''.join(elem for elem in res_context_list if isinstance(elem, str)))

+            # res_context_list may mix strings and token-ID lists; normalize everything to text
+            processed_context = []
+            for context in res_context_list:
+                if isinstance(context, str):
+                    processed_context.append(context)
+                elif isinstance(context, list) and all(isinstance(x, int) for x in context):
+                    # decode token IDs back to text
+                    decoded_text = self.template.tokenizer.decode(context)
+                    processed_context.append(decoded_text)
+                else:
+                    # fall back to the string representation for any other element type
+                    processed_context.append(str(context))
+            prompts_text.append(''.join(processed_context))
         return prompts_text

     @profiling_decorator
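For reference, a minimal standalone sketch of the flattening logic added above: `_swift_encode` may interleave plain strings with lists of token IDs, and the new loop decodes the latter rather than silently dropping them. `flatten_context_list` and `fake_decode` are hypothetical names introduced here for illustration; the real code calls `self.template.tokenizer.decode`, and the sample inputs are invented.

# Illustrative sketch only; not part of the commit.
def flatten_context_list(res_context_list, decode):
    """Join mixed string / token-ID contexts into a single prompt string."""
    processed = []
    for context in res_context_list:
        if isinstance(context, str):
            processed.append(context)
        elif isinstance(context, list) and all(isinstance(x, int) for x in context):
            processed.append(decode(context))  # token IDs -> text
        else:
            processed.append(str(context))     # any other element type
    return ''.join(processed)


# Hypothetical decoder standing in for self.template.tokenizer.decode
fake_decode = lambda ids: '<%d toks>' % len(ids)
print(flatten_context_list(['<user>', [101, 102, 103], -100], fake_decode))
# -> '<user><3 toks>-100'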
@@ -1421,7 +1433,7 @@ def _process_infer_requests_images(self, infer_requests: InputsType):
         return

     def old_policy(self):
-        return self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps
+        return self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0

     @property
     def _queue(self):
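A quick comparison of the old and new conditions on made-up step counts; both helpers below are illustrative stand-ins for the one-line method, not trainer code.

# Illustrative only; the step counts are made-up values.
def old_check(num_iterations, steps_per_generation, gradient_accumulation_steps):
    return num_iterations > 1 or steps_per_generation > gradient_accumulation_steps

def new_check(num_iterations, steps_per_generation, gradient_accumulation_steps):
    return num_iterations > 1 or gradient_accumulation_steps % steps_per_generation != 0

# steps_per_generation=3, gradient_accumulation_steps=8:
print(old_check(1, 3, 8), new_check(1, 3, 8))  # False True  <- the two checks differ here
# steps_per_generation=4, gradient_accumulation_steps=8:
print(old_check(1, 4, 8), new_check(1, 4, 8))  # False False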
@@ -1580,18 +1592,40 @@ def is_async_generate_eval_rollout_done(self):
     def is_async_generate_train_rollout_done(self):
         return not self.train_queue.empty()

-    def inputs_to_rolloutrequest(self, inputs: InputsType) -> RolloutInferRequest:
+    def inputs_to_rolloutrequest(self, inputs: InputsType) -> List[RolloutInferRequest]:
+        """Convert a list of inputs into a list of RolloutInferRequest objects.
+
+        If an input contains a 'data_dict' key, that dict is used as the base for the
+        request's data_dict. For overlapping keys, the values from data_dict take
+        priority; non-overlapping keys are added to data_dict.
+
+        Args:
+            inputs: List of input dictionaries.

+        Returns:
+            List of RolloutInferRequest objects.
+        """
         request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
-        infer_requests = [
-            RolloutInferRequest(
-                **{
-                    **{k: request[k]
-                       for k in request_keys if k in request}, 'data_dict':
-                    {k: request[k]
-                     for k in request if k not in request_keys}
-                }) for request in inputs
-        ]
+        infer_requests = []
+
+        for request in inputs:
+            # Use the input's own data_dict as the base, if present
+            base_data_dict = {}
+            if 'data_dict' in request:
+                if isinstance(request['data_dict'], dict):
+                    base_data_dict = request['data_dict']
+                else:
+                    raise ValueError('data_dict exists but is not a dictionary')
+
+            # Collect all items outside request_keys as extra fields
+            extra_data = {k: request[k] for k in request if k not in request_keys and k != 'data_dict'}
+
+            # Merge, giving priority to keys from base_data_dict
+            final_data_dict = {**extra_data, **base_data_dict}
+
+            # Create the RolloutInferRequest instance
+            req_args = {k: request[k] for k in request_keys if k in request}
+            infer_requests.append(RolloutInferRequest(**req_args, data_dict=final_data_dict))

         return infer_requests

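To make the merge priority concrete, a small standalone sketch of the data_dict handling above; the 'solution' field and its values are invented for illustration.

# Illustrative only; 'solution' and its values are made up.
request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects']
request = {
    'messages': [{'role': 'user', 'content': 'hi'}],
    'solution': 'top-level value',                  # extra field outside request_keys
    'data_dict': {'solution': 'data_dict value'},   # pre-existing data_dict
}

extra_data = {k: request[k] for k in request if k not in request_keys and k != 'data_dict'}
final_data_dict = {**extra_data, **request['data_dict']}
print(final_data_dict)  # {'solution': 'data_dict value'} -> data_dict wins on overlap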