@@ -155,7 +155,7 @@ def _init_env_clients(self) -> List[Any]: # Renamed and return type changed
                 # raise ValueError(f"Task class {task_class_name} did not provide a client for port {port}.")
             except Exception as e:
                 print(f" - Client {i + 1}: Error initializing Task or getting client for port {port}: {e}")
-                print(traceback.format_exc())  # Print detailed traceback
+                print(traceback.format_exc())
                 # Decide how to handle failure: raise error or skip? Skipping for now.
                 # raise
@@ -336,7 +336,8 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int, c
             'step_rewards': step_rewards,  # List of rewards from each env.step call
             'reward': final_reward,  # Reward from the *last* env.step call
             'env_score': final_env_score,  # Final score reported by env info
-            'turns': turns,
+            'turns': turns,  # Total number of turns executed
+            'valid_actions': len([msg for msg in trajectory if msg.get("from") == "gpt"]),  # Count of the agent's responses
             'task_idx': task_idx,
             'done': done  # Whether the episode finished naturally or via error
         }
@@ -422,17 +423,45 @@ def run_llm_loop(self, gen_batch: DataProto, output_dir: str = None, global_step
         if not valid_results:
             print("[Agent.run_llm_loop] Error: No valid rollout results collected.")
             # Return empty DataProto but with correct structure if possible
-            return DataProto.from_dict({
+            empty_proto = DataProto.from_dict({
                 "input_ids": torch.empty((0, 0), dtype=torch.long),
                 "attention_mask": torch.empty((0, 0), dtype=torch.long),
                 "position_ids": torch.empty((0, 0), dtype=torch.long),
                 "info_mask": torch.empty((0, 0), dtype=torch.long),
                 "token_level_rewards": torch.empty((0, 0), dtype=torch.float)
             })
+            # Add necessary meta_info for the downstream compute_log_prob call
+            empty_proto.meta_info = {'micro_batch_size': 1}
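+            # NOTE: micro_batch_size=1 is a conservative placeholder; the batch is empty,
+            # so the value only needs to keep compute_log_prob from failing on a missing key.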
+            return empty_proto
 
         # --- Format Results into DataProto ---
         processed_data = self._convert_rollout_results_to_dataproto(valid_results, gen_batch)
 
+        # --- CRITICAL: Add necessary meta_info parameters for compute_log_prob ---
+        # These parameters are required by DataParallelActor.compute_log_prob.
+        # Source the values from the actor_rollout_wg config or the AgentConfig.
+        log_prob_micro_batch_size = getattr(self.actor_rollout_wg, 'log_prob_micro_batch_size', 128)
+        if hasattr(self.config, 'actor_rollout_ref') and hasattr(self.config.actor_rollout_ref, 'rollout'):
+            # If running within the trainer, which has direct access to these configs
+            log_prob_micro_batch_size = getattr(self.config.actor_rollout_ref.rollout, 'log_prob_micro_batch_size', log_prob_micro_batch_size)
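+        # Net effect is a fallback chain: worker-group attribute -> trainer config
+        # (actor_rollout_ref.rollout) -> hard default of 128.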
+
+        # Ensure these keys exist with reasonable default values even if not specified in config
+        if 'micro_batch_size' not in processed_data.meta_info:
+            processed_data.meta_info['micro_batch_size'] = log_prob_micro_batch_size
+
+        if 'temperature' not in processed_data.meta_info:
+            processed_data.meta_info['temperature'] = getattr(self.config, 'temperature', 1.0)
+
+        if 'use_dynamic_bsz' not in processed_data.meta_info:
+            processed_data.meta_info['use_dynamic_bsz'] = getattr(self.config, 'log_prob_use_dynamic_bsz', False)
+
+        # If dynamic batch size is used, also set max_token_len
+        if processed_data.meta_info.get('use_dynamic_bsz', False):
+            max_token_len = getattr(self.config, 'log_prob_max_token_len_per_gpu', 2048)
+            processed_data.meta_info['max_token_len'] = max_token_len
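+            # With dynamic batching, log-prob micro-batches are packed up to max_token_len
+            # tokens per GPU instead of a fixed sample count (this reading of use_dynamic_bsz
+            # follows the usual verl convention; treat it as an assumption).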
+
+        print(f"[Agent.run_llm_loop] Added log_prob parameters to meta_info: micro_batch_size={processed_data.meta_info['micro_batch_size']}, temperature={processed_data.meta_info['temperature']}, use_dynamic_bsz={processed_data.meta_info['use_dynamic_bsz']}")
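+        # Example of the resulting meta_info (a sketch; actual values depend on config):
+        #   {'micro_batch_size': 128, 'temperature': 1.0, 'use_dynamic_bsz': False, ...}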
+
         print(f"[Agent.run_llm_loop] Finished processing rollout results.")
         return processed_data
 
@@ -454,9 +483,25 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba
         batch_position_ids = []
         batch_info_mask = []
         batch_token_level_rewards = []  # Store final token-level rewards for PPO
-        batch_meta_info = defaultdict(list)
         batch_responses = []  # Initialize batch_responses
 
+        # Initialize final_meta_info by copying all items from original_batch.meta_info.
+        # This ensures that any global metadata from the input batch is preserved.
+        final_meta_info = {}
+        if hasattr(original_batch, 'meta_info') and original_batch.meta_info:
+            for k, v in original_batch.meta_info.items():
+                final_meta_info[k] = v  # Shallow copy; use deepcopy if mutable objects are a concern
+
+        # For collecting stats and per-rollout lists that will be converted to tensors or kept as lists
+        per_rollout_task_idx = []
+        per_rollout_turns_stats = []
+        per_rollout_valid_action_stats = []
+        per_rollout_done_flags = []
+        per_rollout_valid_search_stats = []  # Placeholder
+        per_rollout_rewards = []  # Last step reward for each rollout
+        per_rollout_env_scores = []  # Final env score for each rollout
+        per_rollout_trajectories = []  # List of trajectories
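+        # NOTE: these lists are index-aligned; entry i in each list describes the i-th
+        # result processed below.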
+
         # Get reward allocation strategy from config
         reward_allocation = "last_token"  # Default
         if self.config.algorithm_config:
@@ -478,12 +523,35 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba
 
             turns = result_dict.get('turns', 0)
             task_idx = result_dict.get('task_idx', -1)
-
-            # Get the original batch index
+            valid_actions_count = result_dict.get('valid_actions', 0)
+            done_flag = result_dict.get('done', True)  # Default to True if missing, indicating completion or error
+            reward_val = result_dict.get('reward', 0.0)
+            env_score_val = result_dict.get('env_score', 0.0)
+            trajectory_val = result_dict.get('trajectory', [])
+
+            # Append to the per_rollout_ lists
+            per_rollout_task_idx.append(task_idx)
+            per_rollout_turns_stats.append(turns)
+            per_rollout_valid_action_stats.append(valid_actions_count)
+            per_rollout_done_flags.append(done_flag)
+            per_rollout_valid_search_stats.append(0)  # Placeholder; search is not explicitly tracked here
+            per_rollout_rewards.append(reward_val)
+            per_rollout_env_scores.append(env_score_val)
+            per_rollout_trajectories.append(trajectory_val)
+
+            # Get the original batch index (used for trajectory processing below)
             original_batch_idx = original_indices_map.get(task_idx, -1)
             if original_batch_idx == -1:
-                print(f"[Agent._convert_rollout] Warning: Task idx {task_idx} not found in original batch. Skipping.")
-                continue
+                print(f"[Agent._convert_rollout] Warning: Task idx {task_idx} not found in original batch. Skipping this result for trajectory processing.")
+                # NOTE: this result's stats were already appended to the per_rollout_ lists
+                # above, so the stats lists can end up longer than the stacked tensors.
+                # A robust fix is to filter unmappable results before this loop; for now we
+                # fall through and let the tensor-creation step surface any length mismatch.
+                pass  # original_batch_idx is only needed for the trajectory processing below
 
             # --- Concatenate conversation and identify agent segments ---
             conversation_ids_list = []
@@ -675,20 +743,25 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba
             batch_responses.append(response_only_ids_padded)
 
             # Add metadata
-            batch_meta_info["task_idx"].append(task_idx)
-            batch_meta_info["turns_stats"].append(turns)
-            batch_meta_info["valid_action_stats"].append(valid_actions)
-            batch_meta_info["reward"].append(result_dict.get('reward', 0.0))  # Last step reward
-            batch_meta_info["env_score"].append(result_dict.get('env_score', 0.0))  # Final env score
-            batch_meta_info["rollout_trajectory"].append(trajectory)
-            # Copy relevant metadata from original_batch
-            for key, value in original_batch.meta_info.items():
-                if key not in ['idx', 'reward', 'env_score']:  # Avoid duplication
-                    if isinstance(value, list) and len(value) > original_batch_idx:
-                        batch_meta_info[key].append(value[original_batch_idx])
-                    elif not isinstance(value, list):  # Keep non-list metadata
-                        if task_idx == original_indices[0]:  # Add only once per batch
-                            batch_meta_info[key] = value
746+ if "task_idx" not in final_meta_info :
747+ final_meta_info ["task_idx" ] = []
748+ if "turns_stats" not in final_meta_info :
749+ final_meta_info ["turns_stats" ] = []
750+ if "valid_action_stats" not in final_meta_info :
751+ final_meta_info ["valid_action_stats" ] = []
752+ if "reward" not in final_meta_info :
753+ final_meta_info ["reward" ] = []
754+ if "env_score" not in final_meta_info :
755+ final_meta_info ["env_score" ] = []
756+ if "rollout_trajectory" not in final_meta_info :
757+ final_meta_info ["rollout_trajectory" ] = []
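+            # These per-sample lists are placeholders; they are replaced by stacked
+            # tensors in the final meta_info assembly below.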
+
+            final_meta_info["task_idx"].append(task_idx)
+            final_meta_info["turns_stats"].append(turns)
+            final_meta_info["valid_action_stats"].append(valid_actions_count)
+            final_meta_info["reward"].append(reward_val)
+            final_meta_info["env_score"].append(env_score_val)
+            final_meta_info["rollout_trajectory"].append(trajectory_val)
 
693766 # --- Stack Tensors ---
694767 if not batch_input_ids :
@@ -707,36 +780,38 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba
707780 "input_ids" : torch .cat (batch_input_ids , dim = 0 ),
708781 "attention_mask" : torch .cat (batch_attention_mask , dim = 0 ),
709782 "position_ids" : torch .cat (batch_position_ids , dim = 0 ),
710- "info_mask" : torch .cat (batch_info_mask , dim = 0 ),
783+ "info_mask" : torch .cat (batch_info_mask , dim = 0 ), # This is the equivalent of responses_with_info_mask related construction
711784 "token_level_rewards" : torch .cat (batch_token_level_rewards , dim = 0 ),
712785 "responses" : torch .cat (batch_responses , dim = 0 )
713786 }
714787
715788 # Create DataProto and add metadata
716789 data_proto = DataProto .from_dict (final_batch )
-        for key, value in batch_meta_info.items():
-            try:
-                if isinstance(value, list) and all(isinstance(item, (int, float)) for item in value):
-                    data_proto.meta_info[key] = torch.tensor(value)
-                # Handle numpy arrays if they appear
-                elif isinstance(value, np.ndarray):
-                    data_proto.meta_info[key] = torch.from_numpy(value)
-                else:
-                    # Keep as list for non-numeric types (like trajectories)
-                    data_proto.meta_info[key] = value
-            except (ValueError, TypeError, RuntimeError) as e:
-                # Fallback: keep as list if tensor conversion fails
-                print(f"[Agent._convert_rollout] Warning: Could not convert metadata '{key}' to tensor: {e}. Keeping as list.")
-                data_proto.meta_info[key] = value
-
-        # Explicitly add final env scores as a tensor if possible
-        if "env_score" in batch_meta_info:
-            try:
-                data_proto.meta_info["env_scores"] = torch.tensor(batch_meta_info["env_score"], dtype=torch.float32)
-            except (ValueError, TypeError):
-                # Fallback case
-                print("[Agent._convert_rollout] Could not convert env_scores to tensor, keeping original list.")
-                data_proto.meta_info["env_scores"] = batch_meta_info["env_score"]
+
+        # Add collected statistics and per-rollout lists to final_meta_info, converting to tensors where appropriate.
+        # These overwrite any same-named per-sample list keys inherited from original_batch.meta_info.
+        final_meta_info['task_idx'] = torch.tensor(per_rollout_task_idx, dtype=torch.long)
+        final_meta_info['turns_stats'] = torch.tensor(per_rollout_turns_stats, dtype=torch.long)
+        final_meta_info['valid_action_stats'] = torch.tensor(per_rollout_valid_action_stats, dtype=torch.long)
+        final_meta_info['valid_search_stats'] = torch.tensor(per_rollout_valid_search_stats, dtype=torch.long)  # Will be zeros
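+        # active_mask is True for rollouts that were still running (done=False) at the
+        # end of the loop, so downstream consumers can separate truncated episodes.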
+        final_meta_info['active_mask'] = torch.tensor([not done for done in per_rollout_done_flags], dtype=torch.bool)
+        final_meta_info['reward'] = torch.tensor(per_rollout_rewards, dtype=torch.float32)  # Individual rewards per rollout
+        final_meta_info['env_score'] = torch.tensor(per_rollout_env_scores, dtype=torch.float32)  # Final scores per rollout
+        final_meta_info['rollout_trajectory'] = per_rollout_trajectories  # Keep as a list of lists/dicts
+
+        # The new 'task_idx' tensor is now the authoritative per-sample index for this
+        # processed batch. If original_batch.meta_info carried an 'idx' key that is not
+        # already a tensor, replace it to avoid confusion downstream.
+        if 'idx' in final_meta_info and not torch.is_tensor(final_meta_info['idx']):
+            print(f"[Agent._convert_rollout] Replacing original 'idx' with new 'task_idx' tensor.")
+            final_meta_info['idx'] = final_meta_info['task_idx']
+        elif 'idx' not in final_meta_info:
+            final_meta_info['idx'] = final_meta_info['task_idx']
+
+        # Assign the fully constructed final_meta_info to the DataProto object
+        data_proto.meta_info = final_meta_info
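+        # NOTE: this assignment replaces whatever meta_info DataProto.from_dict may have
+        # set (an assumption about DataProto's behavior); merge instead if from_dict
+        # populates keys that must be kept.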
 
         print(f"[Agent._convert_rollout] Final batch shapes: input_ids={final_batch['input_ids'].shape}, token_level_rewards={final_batch['token_level_rewards'].shape}, responses={final_batch['responses'].shape}")
         return data_proto