diff --git a/openmanus_rl/agentgym/agentenv-webshop/environment.yml b/openmanus_rl/agentgym/agentenv-webshop/environment.yml index 2d69f271..a5859492 100644 --- a/openmanus_rl/agentgym/agentenv-webshop/environment.yml +++ b/openmanus_rl/agentgym/agentenv-webshop/environment.yml @@ -3,6 +3,6 @@ channels: - conda-forge - defaults dependencies: - - python=3.8.13=ha86cf86_0_cpython + - python=3.8.13 - faiss-cpu=1.7.4 - - openjdk=11.0.21=h4260e57_0 + - openjdk=11.0.21 diff --git a/openmanus_rl/llm_agent/openmanus.py b/openmanus_rl/llm_agent/openmanus.py index 1315315f..96e3b6c5 100644 --- a/openmanus_rl/llm_agent/openmanus.py +++ b/openmanus_rl/llm_agent/openmanus.py @@ -155,7 +155,7 @@ def _init_env_clients(self) -> List[Any]: # Renamed and return type changed # raise ValueError(f"Task class {task_class_name} did not provide a client for port {port}.") except Exception as e: print(f" - Client {i+1}: Error initializing Task or getting client for port {port}: {e}") - print(traceback.format_exc()) # Print detailed traceback + print(traceback.format_exc()) # Decide how to handle failure: raise error or skip? Skipping for now. # raise @@ -336,7 +336,8 @@ def _run_single_rollout(self, initial_prompt_ids: torch.Tensor, task_idx: int, c 'step_rewards': step_rewards, # List of rewards from each env.step call 'reward': final_reward, # Reward from the *last* env.step call 'env_score': final_env_score, # Final score reported by env info - 'turns': turns, + 'turns': turns, # Total number of turns executed + 'valid_actions': len([msg for msg in trajectory if msg.get("from") == "gpt"]), # Count of agent's responses 'task_idx': task_idx, 'done': done # Whether the episode finished naturally or via error } @@ -422,17 +423,45 @@ def run_llm_loop(self, gen_batch: DataProto, output_dir: str = None, global_step if not valid_results: print("[Agent.run_llm_loop] Error: No valid rollout results collected.") # Return empty DataProto but with correct structure if possible - return DataProto.from_dict({ + empty_proto = DataProto.from_dict({ "input_ids": torch.empty((0,0), dtype=torch.long), "attention_mask": torch.empty((0,0), dtype=torch.long), "position_ids": torch.empty((0,0), dtype=torch.long), "info_mask": torch.empty((0,0), dtype=torch.long), "token_level_rewards": torch.empty((0,0), dtype=torch.float) }) + # Add necessary meta_info for downstream compute_log_prob call + empty_proto.meta_info = {'micro_batch_size': 1} + return empty_proto # --- Format Results into DataProto --- processed_data = self._convert_rollout_results_to_dataproto(valid_results, gen_batch) + # --- CRITICAL: Add necessary meta_info parameters for compute_log_prob --- + # These parameters are required by DataParallelActor.compute_log_prob + # Source values from the actor_rollout_wg config or AgentConfig + log_prob_micro_batch_size = getattr(self.actor_rollout_wg, 'log_prob_micro_batch_size', 128) + if hasattr(self.config, 'actor_rollout_ref') and hasattr(self.config.actor_rollout_ref, 'rollout'): + # If running within the trainer which has direct access to these configs + log_prob_micro_batch_size = getattr(self.config.actor_rollout_ref.rollout, 'log_prob_micro_batch_size', log_prob_micro_batch_size) + + # Ensure these keys exist and have reasonable default values even if not specified in config + if 'micro_batch_size' not in processed_data.meta_info: + processed_data.meta_info['micro_batch_size'] = log_prob_micro_batch_size + + if 'temperature' not in processed_data.meta_info: + processed_data.meta_info['temperature'] = getattr(self.config, 
'temperature', 1.0) + + if 'use_dynamic_bsz' not in processed_data.meta_info: + processed_data.meta_info['use_dynamic_bsz'] = getattr(self.config, 'log_prob_use_dynamic_bsz', False) + + # If dynamic batch size is used, also set max_token_len + if processed_data.meta_info.get('use_dynamic_bsz', False): + max_token_len = getattr(self.config, 'log_prob_max_token_len_per_gpu', 2048) + processed_data.meta_info['max_token_len'] = max_token_len + + print(f"[Agent.run_llm_loop] Added log_prob parameters to meta_info: micro_batch_size={processed_data.meta_info['micro_batch_size']}, temperature={processed_data.meta_info['temperature']}, use_dynamic_bsz={processed_data.meta_info['use_dynamic_bsz']}") + print(f"[Agent.run_llm_loop] Finished processing rollout results.") return processed_data @@ -454,9 +483,25 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba batch_position_ids = [] batch_info_mask = [] batch_token_level_rewards = [] # Store final token-level rewards for PPO - batch_meta_info = defaultdict(list) batch_responses = [] # Initialize batch_responses + # Initialize final_meta_info by copying all items from the original_batch.meta_info + # This ensures that any global metadata from the input batch is preserved. + final_meta_info = {} + if hasattr(original_batch, 'meta_info') and original_batch.meta_info: + for k, v in original_batch.meta_info.items(): + final_meta_info[k] = v # Shallow copy, or deepcopy if mutable objects are a concern + + # For collecting stats and per-rollout lists that will be converted to tensors or kept as lists + per_rollout_task_idx = [] + per_rollout_turns_stats = [] + per_rollout_valid_action_stats = [] + per_rollout_done_flags = [] + per_rollout_valid_search_stats = [] # Placeholder + per_rollout_rewards = [] # Last step reward for each rollout + per_rollout_env_scores = [] # Final env score for each rollout + per_rollout_trajectories = [] # List of trajectories + # Get reward allocation strategy from config reward_allocation = "last_token" # Default if self.config.algorithm_config: @@ -478,12 +523,35 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba turns = result_dict.get('turns', 0) task_idx = result_dict.get('task_idx', -1) - - # Get the original batch index + valid_actions_count = result_dict.get('valid_actions', 0) + done_flag = result_dict.get('done', True) # Default to True if missing, indicating completion or error + reward_val = result_dict.get('reward', 0.0) + env_score_val = result_dict.get('env_score', 0.0) + trajectory_val = result_dict.get('trajectory', []) + + # Correctly append to per_rollout_ lists + per_rollout_task_idx.append(task_idx) + per_rollout_turns_stats.append(turns) + per_rollout_valid_action_stats.append(valid_actions_count) + per_rollout_done_flags.append(done_flag) + per_rollout_valid_search_stats.append(0) # Placeholder, as search is not explicitly tracked here + per_rollout_rewards.append(reward_val) + per_rollout_env_scores.append(env_score_val) + per_rollout_trajectories.append(trajectory_val) + + # Get the original batch index (used for trajectory processing below) original_batch_idx = original_indices_map.get(task_idx, -1) if original_batch_idx == -1: - print(f"[Agent._convert_rollout] Warning: Task idx {task_idx} not found in original batch. Skipping.") - continue + print(f"[Agent._convert_rollout] Warning: Task idx {task_idx} not found in original batch. 
Skipping this result for trajectory processing.") + # If a result can't be mapped, its trajectory-derived tensors might be misaligned. + # For simplicity, we might skip creating tensor entries for it, or handle padding carefully. + # However, its stats (task_idx, turns, etc.) are already appended to per_rollout_ lists. + # This might lead to length mismatches if not handled carefully when creating final tensors. + # A robust solution would be to filter results list upfront or ensure all task_idx are mappable. + # For now, we proceed, and downstream tensor creation should handle potential Nones if any result is fully skipped. + # OR, more simply, if we can't map, we might have to skip this entire result_dict earlier. + # For now, let the per_rollout lists gather all data, and mismatches will be an issue at tensor conversion. + pass # Original_batch_idx is used for trajectory processing, not for the stats lists directly. # --- Concatenate conversation and identify agent segments --- conversation_ids_list = [] @@ -675,20 +743,25 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba batch_responses.append(response_only_ids_padded) # Add metadata - batch_meta_info["task_idx"].append(task_idx) - batch_meta_info["turns_stats"].append(turns) - batch_meta_info["valid_action_stats"].append(valid_actions) - batch_meta_info["reward"].append(result_dict.get('reward', 0.0)) # Last step reward - batch_meta_info["env_score"].append(result_dict.get('env_score', 0.0)) # Final env score - batch_meta_info["rollout_trajectory"].append(trajectory) - # Copy relevant metadata from original_batch - for key, value in original_batch.meta_info.items(): - if key not in ['idx', 'reward', 'env_score']: # Avoid duplication - if isinstance(value, list) and len(value) > original_batch_idx: - batch_meta_info[key].append(value[original_batch_idx]) - elif not isinstance(value, list): # Keep non-list metadata - if task_idx == original_indices[0]: # Add only once per batch - batch_meta_info[key] = value + if "task_idx" not in final_meta_info: + final_meta_info["task_idx"] = [] + if "turns_stats" not in final_meta_info: + final_meta_info["turns_stats"] = [] + if "valid_action_stats" not in final_meta_info: + final_meta_info["valid_action_stats"] = [] + if "reward" not in final_meta_info: + final_meta_info["reward"] = [] + if "env_score" not in final_meta_info: + final_meta_info["env_score"] = [] + if "rollout_trajectory" not in final_meta_info: + final_meta_info["rollout_trajectory"] = [] + + final_meta_info["task_idx"].append(task_idx) + final_meta_info["turns_stats"].append(turns) + final_meta_info["valid_action_stats"].append(valid_actions_count) + final_meta_info["reward"].append(reward_val) + final_meta_info["env_score"].append(env_score_val) + final_meta_info["rollout_trajectory"].append(trajectory_val) # --- Stack Tensors --- if not batch_input_ids: @@ -707,36 +780,38 @@ def _convert_rollout_results_to_dataproto(self, results: List[Dict], original_ba "input_ids": torch.cat(batch_input_ids, dim=0), "attention_mask": torch.cat(batch_attention_mask, dim=0), "position_ids": torch.cat(batch_position_ids, dim=0), - "info_mask": torch.cat(batch_info_mask, dim=0), + "info_mask": torch.cat(batch_info_mask, dim=0), # This is the equivalent of responses_with_info_mask related construction "token_level_rewards": torch.cat(batch_token_level_rewards, dim=0), "responses": torch.cat(batch_responses, dim=0) } # Create DataProto and add metadata data_proto = DataProto.from_dict(final_batch) - for key, value in 
batch_meta_info.items(): - try: - if isinstance(value, list) and all(isinstance(item, (int, float)) for item in value): - data_proto.meta_info[key] = torch.tensor(value) - # Handle numpy arrays if they appear - elif isinstance(value, np.ndarray): - data_proto.meta_info[key] = torch.from_numpy(value) - else: - # Keep as list for non-numeric types (like trajectories) - data_proto.meta_info[key] = value - except (ValueError, TypeError, RuntimeError) as e: - # Fallback: keep as list if tensor conversion fails - print(f"[Agent._convert_rollout] Warning: Could not convert metadata '{key}' to tensor: {e}. Keeping as list.") - data_proto.meta_info[key] = value - - # Explicitly add final env scores as a tensor if possible - if "env_score" in batch_meta_info: - try: - data_proto.meta_info["env_scores"] = torch.tensor(batch_meta_info["env_score"], dtype=torch.float32) - except (ValueError, TypeError): - # Fallback case - print("[Agent._convert_rollout] Could not convert env_scores to tensor, keeping original list.") - data_proto.meta_info["env_scores"] = batch_meta_info["env_score"] + + # Add collected statistics and per-rollout lists to final_meta_info, converting to tensors where appropriate + # These will overwrite any keys with the same name inherited from original_batch.meta_info if they were lists per sample. + final_meta_info['task_idx'] = torch.tensor(per_rollout_task_idx, dtype=torch.long) + final_meta_info['turns_stats'] = torch.tensor(per_rollout_turns_stats, dtype=torch.long) + final_meta_info['valid_action_stats'] = torch.tensor(per_rollout_valid_action_stats, dtype=torch.long) + final_meta_info['valid_search_stats'] = torch.tensor(per_rollout_valid_search_stats, dtype=torch.long) # Will be zeros + final_meta_info['active_mask'] = torch.tensor([not done for done in per_rollout_done_flags], dtype=torch.bool) + final_meta_info['reward'] = torch.tensor(per_rollout_rewards, dtype=torch.float32) # Individual rewards per rollout + final_meta_info['env_score'] = torch.tensor(per_rollout_env_scores, dtype=torch.float32) # Final scores per rollout + final_meta_info['rollout_trajectory'] = per_rollout_trajectories # Keep as list of lists/dicts + + # If 'idx' was in original_batch.meta_info and was a tensor, it might have been copied directly. + # If it needs to be specifically task_idx, the above 'task_idx' tensor is now authoritative for the samples in this batch. + # We can choose to remove the original 'idx' if it causes confusion or ensure it's compatible. + # For now, the new 'task_idx' list converted to a tensor becomes the primary index for these processed samples. 
+ if 'idx' in final_meta_info and not torch.is_tensor(final_meta_info['idx']): + # If original idx was not a tensor or needs to be sample-specific for this processed batch + print(f"[Agent._convert_rollout] Replacing original 'idx' with new 'task_idx' tensor.") + final_meta_info['idx'] = final_meta_info['task_idx'] + elif 'idx' not in final_meta_info: + final_meta_info['idx'] = final_meta_info['task_idx'] + + # Assign the fully constructed final_meta_info to the DataProto object + data_proto.meta_info = final_meta_info print(f"[Agent._convert_rollout] Final batch shapes: input_ids={final_batch['input_ids'].shape}, token_level_rewards={final_batch['token_level_rewards'].shape}, responses={final_batch['responses'].shape}") return data_proto diff --git a/train_ppo.sh b/train_ppo.sh index 8645e1a1..6d87211d 100644 --- a/train_ppo.sh +++ b/train_ppo.sh @@ -1,16 +1,12 @@ #!/bin/bash # --- Configuration (defaults, can be overridden via env vars) --- -export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,5,9} +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-1,4,5} WAND_PROJECT=${WAND_PROJECT:-'OpenManus-rl'} -export BASE_MODEL=${BASE_MODEL:-'Qwen/Qwen2.5-3B'} +export BASE_MODEL=${BASE_MODEL:-'../model/Qwen2.5-3B'} AGENTGYM_HOST=${AGENTGYM_HOST:-'0.0.0.0'} # Default to 0.0.0.0 for external access AGENTGYM_SQL_BIRD_PATH=${AGENTGYM_SQL_BIRD_PATH:-} # Used only for sqlgym -export NCCL_IB_DISABLE=1 -export NCCL_P2P_DISABLE=1 export PYTHONPATH="./openmanus_rl/agentgym/agentenv:${PYTHONPATH}" -export VLLM_ATTENTION_BACKEND=XFORMERS # vllm + qwen2-7b with flash_attn has some issues - # --- Argument Parsing --- usage() { @@ -230,6 +226,7 @@ export EXPERIMENT_NAME="OpenManus-rl-ppo-${BASE_MODEL##*/}-${AGENTGYM_ENV_NAME}$ # --- Run PPO Training in Base Environment --- echo -e "\\n[Trainer] Running PPO training in base environment '$BASE_CONDA_ENV'..." 
+export VLLM_ATTENTION_BACKEND=${VLLM_ATTENTION_BACKEND:-XFORMERS} # Construct server base URL, adding path if needed AGENTGYM_SERVER_BASE="http://$AGENTGYM_HOST" # Base URL without port @@ -284,7 +281,7 @@ hydra_overrides=( "data.env_ports=[${AGENTGYM_PORTS_STR}]" "data.train_data_num=null" "data.val_data_num=null" - "data.train_batch_size=6" + "data.train_batch_size=3" "data.val_batch_size=3" "data.max_prompt_length=4096" "data.max_response_length=1000" @@ -297,8 +294,8 @@ hydra_overrides=( "actor_rollout_ref.model.enable_gradient_checkpointing=true" "actor_rollout_ref.model.use_remove_padding=True" "actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95" - "actor_rollout_ref.actor.ppo_mini_batch_size=6" - "actor_rollout_ref.actor.ppo_micro_batch_size=6" + "actor_rollout_ref.actor.ppo_mini_batch_size=4" + "actor_rollout_ref.actor.ppo_micro_batch_size=4" "actor_rollout_ref.actor.fsdp_config.param_offload=true" "actor_rollout_ref.actor.fsdp_config.grad_offload=true" "actor_rollout_ref.actor.fsdp_config.optimizer_offload=true" @@ -332,7 +329,7 @@ hydra_overrides=( "trainer.default_hdfs_dir=null" "trainer.n_gpus_per_node=3" "trainer.nnodes=1" - "trainer.save_freq=100" + "trainer.save_freq=1" "trainer.test_freq=50" "trainer.project_name=$WAND_PROJECT" "trainer.experiment_name=$EXPERIMENT_NAME" diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 9e41625f..ea66f97e 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -141,7 +141,7 @@ def main(config): env_vars_for_main_task['CUDA_VISIBLE_DEVICES'] = original_cuda_visible print(f"[main] Runtime env to be passed to main_task actor: {{'env_vars': {env_vars_for_main_task}}}") - ray.get(main_task.options(runtime_env={'env_vars': env_vars_for_main_task}).remote(config)) + ray.get(main_task.remote(config)) print("[main] main_task finished.") ray.shutdown() print("[main] Ray shutdown.") diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 06e53020..2ac996b7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -101,62 +101,62 @@ def get_resource_pool(self, role: Role) -> RayResourcePool: def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'): - responses = data.batch['responses'] - response_length = responses.size(1) - token_level_scores = data.batch['token_level_scores'] # Shape (batch_size, total_length) - total_length = token_level_scores.size(1) - # batch_size = data.batch.batch_size[0] # Get scalar batch size from TensorDict property - - # --- FIX: Get batch size from a tensor inside batch --- - # Using data.batch.batch_size directly might fail if TensorDict is empty or inconsistent during init - # It's safer to get it from a guaranteed tensor like input_ids or attention_mask if available - # However, batch_size for kl_ctrl update needs to be scalar sum of batch sizes across ranks - # Let's rely on the TensorDict property for now, assuming it's consistent by this point. - # If this causes issues later, we might need to pass effective batch size differently. 
- batch_size_scalar = data.batch.batch_size[0] # Get scalar batch size for kl_ctrl.update - # --- END FIX --- - - # Get the attention mask for the full sequence - attention_mask = data.batch['attention_mask'] # Shape (batch_size, total_length) - # Extract the mask corresponding only to the response part - response_mask = attention_mask[:, -response_length:] # Shape (batch_size, response_length) - - # compute kl between ref_policy and current policy - if 'ref_log_prob' in data.batch.keys() and 'old_log_probs' in data.batch.keys(): - # Assuming old_log_probs and ref_log_prob have shape (batch_size, response_length) - kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'], - kl_penalty=kl_penalty) # Shape (batch_size, response_length) - kld = kld * response_mask # Apply mask, shape remains (batch_size, response_length) + responses = data.batch['responses'] # Shape (B, L_resp) + response_length = responses.size(1) # L_resp + token_level_scores = data.batch['token_level_scores'] # Shape (B, L_full) + + # Assuming old_log_probs and ref_log_prob are also L_full + old_log_probs_full = data.batch.get('old_log_probs') + ref_log_prob_full = data.batch.get('ref_log_prob') + + attention_mask_full = data.batch['attention_mask'] # Shape (B, L_full) + # This mask is for the response part only + response_mask = attention_mask_full[:, -response_length:] # Shape (B, L_resp) + + beta = 0.0 + # Initialize with a tensor of correct shape and type for the case where KL is not computed. + kld_response_part_masked = torch.zeros_like(response_mask, dtype=token_level_scores.dtype, device=token_level_scores.device) + + actual_kld_for_metric = 0.0 + + if ref_log_prob_full is not None and old_log_probs_full is not None: + # Calculate KLD over the full length first + kld_full = core_algos.kl_penalty(old_log_probs_full, ref_log_prob_full, kl_penalty=kl_penalty) # Shape (B, L_full) + + # Slice KLD to the response part + kld_response_part = kld_full[:, -response_length:] # Shape (B, L_resp) + + # Apply response_mask to the sliced KLD part + kld_response_part_masked = kld_response_part * response_mask # Element-wise, shapes match beta = kl_ctrl.value - else: - beta = 0 - # kld should have the same shape as the response part it would be subtracted from - kld = torch.zeros_like(response_mask, dtype=torch.float32) # Shape (batch_size, response_length) - - # Initialize token_level_rewards as a copy of scores (prompt rewards are scores) - token_level_rewards = token_level_scores.clone() - - # --- FIX: Apply KL penalty only to the response part --- - # Extract the scores corresponding to the response tokens - response_scores = token_level_scores[:, -response_length:] # Shape (batch_size, response_length) - # Calculate the rewards for the response tokens - response_rewards = response_scores - beta * kld # Shape (batch_size, response_length) - # Place the calculated response rewards back into the full rewards tensor - # Ensure rewards are only applied where the response mask is 1 - token_level_rewards[:, -response_length:][response_mask] = response_rewards[response_mask] - # --- END FIX --- - - # Calculate current_kl based on the response part - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() + + # For KL controller update and metric, use unmasked kld_response_part with response_mask + actual_kld_for_metric = masked_mean(kld_response_part, mask=response_mask, axis=-1) + actual_kld_for_metric = 
torch.mean(actual_kld_for_metric, dim=0).item() + + # Initialize token_level_rewards as a clone of full-length scores + token_level_rewards_full = token_level_scores.clone() # Shape (B, L_full) + + # Slice scores to the response part + scores_response_part = token_level_scores[:, -response_length:] # Shape (B, L_resp) + + # Calculate the rewards for the response tokens by subtracting scaled KLD + # kld_response_part_masked already incorporates the response_mask for zeroing out padded tokens + actual_response_rewards = scores_response_part - beta * kld_response_part_masked # Shape (B, L_resp) + + # Place the calculated response rewards back into the correct segment of the full rewards tensor + # We view the response part of the full tensor and update it using the response_mask. + token_level_rewards_full_response_part_view = token_level_rewards_full[:, -response_length:] + token_level_rewards_full_response_part_view[response_mask] = actual_response_rewards[response_mask] + # Update KL controller - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size_scalar) # Use scalar batch_size + current_batch_size = responses.shape[0] + kl_ctrl.update(current_kl=actual_kld_for_metric, n_steps=current_batch_size) - # Update the DataProto with the final token_level_rewards - data.batch['token_level_rewards'] = token_level_rewards + data.batch['token_level_rewards'] = token_level_rewards_full - metrics = {'critic/kl': current_kl, 'critic/kl_coeff': beta} + metrics = {'critic/kl': actual_kld_for_metric, 'critic/kl_coeff': beta} return data, metrics @@ -164,132 +164,55 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): """ Compute advantage estimates based on the specified estimator (GAE or GRPO). - Now with improved error handling and debugging. + Ensures inputs to core_algos.compute_gae_advantage_return are correctly sliced to response_length. """ - try: - # prepare response group - if adv_estimator == 'gae': - # Check if values field exists, which is required for GAE - if 'values' not in data.batch: - # CHANGE: Throw an error instead of automatically falling back to GRPO - error_msg = "'values' not found in batch, required for GAE. Please ensure critic.compute_values is called before compute_advantage." - print(f"[compute_advantage][ERROR] {error_msg}") - raise ValueError(error_msg) - # Remove the automatic fallback code below - # print(f"[compute_advantage] WARNING: 'values' not found in batch, required for GAE. 
Falling back to GRPO estimator.") - # Fall back to GRPO estimator which doesn't require values - # adv_estimator = 'grpo' - # print(f"[compute_advantage] Switched to estimator: {adv_estimator}") - else: - values = data.batch['values'] # Assume shape (batch_size, response_length), e.g., (4, 1000) - responses = data.batch['responses'] # Shape (batch_size, response_length), e.g., (4, 1000) - token_level_rewards = data.batch['token_level_rewards'] # Shape (batch_size, total_length), e.g., (4, 4096) - attention_mask = data.batch['attention_mask'] # Shape (batch_size, total_length), e.g., (4, 4096) - - response_length = responses.size(-1) # e.g., 1000 - - # Print shapes for debugging - print(f"[compute_advantage][GAE] Response length: {response_length}") - print(f"[compute_advantage][GAE] Values shape: {values.shape}") - print(f"[compute_advantage][GAE] Token level rewards shape: {token_level_rewards.shape}") - print(f"[compute_advantage][GAE] Attention mask shape: {attention_mask.shape}") - - # --- FIX: Extract response-only parts for GAE calculation --- - # Rewards corresponding to the response part - response_rewards = token_level_rewards[:, -response_length:] # Shape (4, 1000) - # Values corresponding to the response part (already assumed to be this shape) - # response_values = values # Shape (4, 1000) # Incorrect assumption, values is full length - # ---> FIX: Slice the values tensor to match the response length <--- - response_values = values[:, -response_length:] - # Mask corresponding to the response part - response_eos_mask = attention_mask[:, -response_length:] # Shape (4, 1000) - # --- END FIX --- - - # Call GAE with aligned tensors - advantages_response, returns_response = core_algos.compute_gae_advantage_return( - token_level_rewards=response_rewards, - values=response_values, # Pass the correctly sliced values - eos_mask=response_eos_mask, - gamma=gamma, - lam=lam - ) - # advantages_response/returns_response have shape (batch_size, response_length) - - # --- FIX: Pad advantages and returns back to the full sequence length --- - total_length = token_level_rewards.size(1) # e.g., 4096 - advantages = torch.zeros_like(token_level_rewards) - returns = torch.zeros_like(token_level_rewards) - - advantages[:, -response_length:] = advantages_response - returns[:, -response_length:] = returns_response - # Apply mask again to ensure padding remains zero - advantages = advantages * attention_mask - returns = returns * attention_mask - # --- END FIX --- - - data.batch['advantages'] = advantages # Shape (4, 4096) - data.batch['returns'] = returns # Shape (4, 4096) - # Successfully computed GAE, return here - return data - - # If we reach here, we're using GRPO or we fell back to GRPO - if adv_estimator == 'grpo': - print(f"[compute_advantage] Computing GRPO advantages...") - if 'token_level_rewards' not in data.batch: - raise KeyError("Missing 'token_level_rewards' in batch, required for GRPO advantage computation") - if 'uid' not in data.non_tensor_batch: - raise KeyError("Missing 'uid' in non_tensor_batch, required for GRPO advantage computation") - if 'responses' not in data.batch: - raise KeyError("Missing 'responses' in batch, required for GRPO advantage computation") - - token_level_rewards = data.batch['token_level_rewards'] - index = data.non_tensor_batch['uid'] - responses = data.batch['responses'] - response_length = responses.size(-1) - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - - print(f"[compute_advantage] GRPO inputs - 
token_level_rewards shape: {token_level_rewards.shape}, " + - f"response_length: {response_length}, response_mask shape: {response_mask.shape}, index length: {len(index)}") - - # GRPO computation with proper response rewards - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=token_level_rewards[:, -response_length:], - eos_mask=response_mask, - index=index - ) - - # Verify the computation results - print(f"[compute_advantage] GRPO outputs - advantages shape: {advantages.shape}, returns shape: {returns.shape}") - - # Pad back to full sequence length - total_length = token_level_rewards.size(1) - padded_advantages = torch.zeros_like(token_level_rewards) - padded_returns = torch.zeros_like(token_level_rewards) - padded_advantages[:, -response_length:] = advantages - padded_returns[:, -response_length:] = returns - - # Apply attention mask and store results - data.batch['advantages'] = padded_advantages * attention_mask - data.batch['returns'] = padded_returns * attention_mask - - print(f"[compute_advantage] GRPO advantages/returns computed successfully") - else: - raise NotImplementedError - - # Check if the computed advantages and returns are valid - if torch.isnan(data.batch['advantages']).any() or torch.isnan(data.batch['returns']).any(): - raise ValueError(f"NaN values detected in computed advantages or returns with {adv_estimator}") - - # Return the updated DataProto - return data - - except Exception as e: - import traceback - print(f"[compute_advantage][ERROR] Failed to compute advantages with {adv_estimator}: {e}") - print(traceback.format_exc()) - raise RuntimeError(f"Advantage computation failed for {adv_estimator}: {e}") + if adv_estimator == 'gae': + values_full = data.batch['values'] # Expected Shape (B, L_full) + responses = data.batch['responses'] # Shape (B, L_resp) + response_length = responses.size(-1) # L_resp + + attention_mask_full = data.batch['attention_mask'] # Shape (B, L_full) + # This is the EoS mask for the response part + response_eos_mask = attention_mask_full[:, -response_length:] # Shape (B, L_resp) + + token_level_rewards_full = data.batch['token_level_rewards'] # Shape (B, L_full) + + # Slice values and token_level_rewards to the response part + values_response_part = values_full[:, -response_length:] # Shape (B, L_resp) + token_level_rewards_response_part = token_level_rewards_full[:, -response_length:] # Shape (B, L_resp) + + # Now all inputs to compute_gae_advantage_return are response-length + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=token_level_rewards_response_part, + values=values_response_part, + eos_mask=response_eos_mask, # This is already the response-specific mask + gamma=gamma, + lam=lam + ) + # advantages and returns will have shape (B, L_resp) + data.batch['advantages'] = advantages + data.batch['returns'] = returns + elif adv_estimator == 'grpo': + token_level_rewards_full = data.batch['token_level_rewards'] + responses = data.batch['responses'] # L_resp + response_length = responses.size(-1) # L_resp + attention_mask_full = data.batch['attention_mask'] # L_full + response_eos_mask = attention_mask_full[:, -response_length:] # Shape (B, L_resp) + + token_level_rewards_response_part = token_level_rewards_full[:, -response_length:] + index = data.non_tensor_batch['uid'] + + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=token_level_rewards_response_part, + eos_mask=response_eos_mask, + index=index + ) + 
data.batch['advantages'] = advantages + data.batch['returns'] = returns + else: + raise NotImplementedError + return data def reduce_metrics(metrics: dict): for key, val in metrics.items(): @@ -560,33 +483,6 @@ def __init__( # Check CUDA availability but don't fail if not available # Instead, log detailed information for diagnostics - print("\n" + "="*60) - print("[RayPPOTrainer.__init__] CUDA Availability Check:") - import os - print(f" CUDA_VISIBLE_DEVICES = {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}") - - if not torch.cuda.is_available(): - print(f" WARNING: CUDA is not available in RayPPOTrainer!") - print(f" This might cause issues for GPU-intensive operations.") - print(f" Try checking if CUDA_VISIBLE_DEVICES was modified by Ray.") - # Continue but warn rather than failing - else: - # Print CUDA info for debugging - device_count = torch.cuda.device_count() - print(f" CUDA is available. Found {device_count} devices.") - for i in range(device_count): - print(f" - GPU {i}: {torch.cuda.get_device_name(i)}") - - # Additional GPU memory info if available - try: - for i in range(device_count): - free_mem, total_mem = torch.cuda.mem_get_info(i) - free_gb = free_mem / (1024**3) - total_gb = total_mem / (1024**3) - print(f" - GPU {i} Memory: {free_gb:.2f}GB free / {total_gb:.2f}GB total") - except: - print(" (GPU memory info not available)") - print("="*60 + "\n") self.tokenizer = tokenizer self.config = config @@ -779,61 +675,62 @@ def _validate(self): # --- Run Validation Loop using OpenManusAgent --- all_metrics = defaultdict(list) - for val_batch in self.val_dataloader: - # Ensure batch is on the correct device (or handled by agent) - # val_batch = val_batch.to(self.rank) # May not be needed if agent handles device placement - - # Agent's run_llm_loop returns a DataProto with results including rewards/scores - processed_batch = self.validation_agent.run_llm_loop(val_batch, self.log_dir, self.global_steps) - - # --- Extract metrics from the agent's output --- - # The reward/score should ideally be in processed_batch.meta_info - # Let's assume 'env_score' holds the final task score per item - if 'env_score' in processed_batch.meta_info: - scores = processed_batch.meta_info['env_score'] - if isinstance(scores, torch.Tensor): - scores = scores.cpu().tolist() - all_metrics['val_reward_score'].extend(scores) # Use a consistent key - all_metrics['env_score'].extend(scores) # Also log as env_score - - # Log other stats if available - if 'turns_stats' in processed_batch.meta_info: - turns = processed_batch.meta_info['turns_stats'] - if isinstance(turns, torch.Tensor): turns = turns.cpu().tolist() - all_metrics['turns_stats'].extend(turns) - - if 'valid_action_stats' in processed_batch.meta_info: - valid_actions = processed_batch.meta_info['valid_action_stats'] - if isinstance(valid_actions, torch.Tensor): valid_actions = valid_actions.cpu().tolist() - all_metrics['valid_action_stats'].extend(valid_actions) + # comment on the validation loop + # for val_batch in self.val_dataloader: + # # Ensure batch is on the correct device (or handled by agent) + # # val_batch = val_batch.to(self.rank) # May not be needed if agent handles device placement + + # # Agent's run_llm_loop returns a DataProto with results including rewards/scores + # processed_batch = self.validation_agent.run_llm_loop(val_batch, self.log_dir, self.global_steps) + + # # --- Extract metrics from the agent's output --- + # # The reward/score should ideally be in processed_batch.meta_info + # # Let's assume 'env_score' holds the 
final task score per item + # if 'env_score' in processed_batch.meta_info: + # scores = processed_batch.meta_info['env_score'] + # if isinstance(scores, torch.Tensor): + # scores = scores.cpu().tolist() + # all_metrics['val_reward_score'].extend(scores) # Use a consistent key + # all_metrics['env_score'].extend(scores) # Also log as env_score + + # # Log other stats if available + # if 'turns_stats' in processed_batch.meta_info: + # turns = processed_batch.meta_info['turns_stats'] + # if isinstance(turns, torch.Tensor): turns = turns.cpu().tolist() + # all_metrics['turns_stats'].extend(turns) + + # if 'valid_action_stats' in processed_batch.meta_info: + # valid_actions = processed_batch.meta_info['valid_action_stats'] + # if isinstance(valid_actions, torch.Tensor): valid_actions = valid_actions.cpu().tolist() + # all_metrics['valid_action_stats'].extend(valid_actions) # Add any other relevant metrics from the agent's output meta_info # ... # --- Optional: Save Trajectories/Visualizations --- # Make sure save_trajectory_to_output is imported - from openmanus_rl.utils.visualization import save_trajectory_to_output - - if self.logger and 'rollout_trajectory' in processed_batch.meta_info: - # Assuming save_rollout_data can handle the trajectory format - # You might need to adapt this based on the logger's interface - try: - task_indices = processed_batch.meta_info.get('task_idx', list(range(len(processed_batch)))) - if isinstance(task_indices, torch.Tensor): task_indices = task_indices.cpu().tolist() - - for idx, trajectory in enumerate(processed_batch.meta_info['rollout_trajectory']): - if idx < 5: # Limit saving to avoid excessive logging - original_task_idx = task_indices[idx] - save_trajectory_to_output( - trajectory, - output_dir=self.log_dir, - global_step=self.global_steps, - task_idx=original_task_idx, - prefix="val" - ) - except Exception as e: - print(f"[Trainer] Warning: Failed to save validation trajectory: {e}") - import traceback + # from openmanus_rl.utils.visualization import save_trajectory_to_output + + # if self.logger and 'rollout_trajectory' in processed_batch.meta_info: + # # Assuming save_rollout_data can handle the trajectory format + # # You might need to adapt this based on the logger's interface + # try: + # task_indices = processed_batch.meta_info.get('task_idx', list(range(len(processed_batch)))) + # if isinstance(task_indices, torch.Tensor): task_indices = task_indices.cpu().tolist() + + # for idx, trajectory in enumerate(processed_batch.meta_info['rollout_trajectory']): + # if idx < 5: # Limit saving to avoid excessive logging + # original_task_idx = task_indices[idx] + # save_trajectory_to_output( + # trajectory, + # output_dir=self.log_dir, + # global_step=self.global_steps, + # task_idx=original_task_idx, + # prefix="val" + # ) + # except Exception as e: + # print(f"[Trainer] Warning: Failed to save validation trajectory: {e}") + # import traceback # traceback.print_exc() # Uncomment for more details # --- Aggregate and Log Metrics --- @@ -915,269 +812,78 @@ def _validate(self): print(f"[Trainer] Warning: No standard validation metrics were aggregated.") return aggregated_metrics - def verify_worker_cuda_setup(self, worker_name, worker_group): - """Verify if worker has correctly set up CUDA devices""" - print(f"\n--- Verifying CUDA for {worker_name} --- ") - try: - worker_info = None - if hasattr(worker_group, 'get_worker_info') and callable(getattr(worker_group, 'get_worker_info')): - # Wrap remote call in try-except - try: - worker_info = 
ray.get(worker_group.get_worker_info.remote()) - print(f"[CUDA DEBUG] {worker_name} worker info (from group): {worker_info}") - except Exception as e_info: - print(f"[CUDA DEBUG][ERROR] Failed to get worker_info for {worker_name}: {e_info}") - - # Remotely check worker's internal CUDA status - gpu_status = None - model_device_info = None - if hasattr(worker_group, 'run_function') and callable(getattr(worker_group, 'run_function')): - # Define check functions to run remotely - def check_gpu_setup_remote(): - import torch, os - cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set') - is_available = torch.cuda.is_available() - count = torch.cuda.device_count() if is_available else 0 - devices = [torch.cuda.get_device_name(i) for i in range(count)] if count > 0 else [] - return { - 'pid': os.getpid(), - 'host': os.uname()[1], - 'CUDA_VISIBLE_DEVICES': cuda_visible, - 'torch.cuda.is_available': is_available, - 'torch.cuda.device_count': count, - 'device_names': devices - } - - def check_model_device_remote(worker_instance): - # Assuming the model is accessible via an attribute like 'model' or similar - # This needs adjustment based on actual worker implementation - model_attr_names = ['actor_module_fsdp', 'critic_module', 'ref_module_fsdp', 'reward_module'] - devices = {} - for attr_name in model_attr_names: - if hasattr(worker_instance, attr_name): - model = getattr(worker_instance, attr_name) - if hasattr(model, 'device'): - devices[attr_name] = str(model.device) - elif hasattr(model, 'module') and hasattr(model.module, 'device'): # Check wrapped module - devices[attr_name] = str(model.module.device) - elif hasattr(model, 'parameters'): - try: - first_param_device = next(model.parameters()).device - devices[attr_name] = str(first_param_device) - except StopIteration: - devices[attr_name] = "No parameters" - return devices if devices else "Model or device info not accessible" - - try: - # Use run_function_on_all_workers_sync or similar if available, - # otherwise run on rank 0. Adjust based on RayWorkerGroup implementation. 
- # Assuming run_function runs on rank 0 by default if not specified: - worker_gpu_check = worker_group.run_function.remote(check_gpu_setup_remote) - gpu_status = ray.get(worker_gpu_check) - print(f"[CUDA DEBUG] {worker_name} internal GPU status: {gpu_status}") - - # Pass 'self' to check model device on the worker instance - model_device_check = worker_group.run_function.remote(check_model_device_remote, args=[worker_group.workers[0]]) # Check on rank 0 - model_device_info = ray.get(model_device_check) - print(f"[CUDA DEBUG] {worker_name} internal model device info: {model_device_info}") - - except Exception as e_remote: - print(f"[CUDA DEBUG][ERROR] Error running remote check on {worker_name}: {e_remote}") - - else: - print(f"[CUDA DEBUG] {worker_name} does not support remote function execution for detailed checks.") - - print(f"--- Verification for {worker_name} complete --- \n") - return gpu_status, model_device_info - - except Exception as e: - print(f"[CUDA DEBUG][ERROR] Error checking {worker_name} CUDA setup: {e}") - import traceback - traceback.print_exc() - print(f"--- Verification for {worker_name} failed --- \n") - return False, None - def init_workers(self): - """Init resource pool and worker group - add GPU device checks and pass assignments""" - # Print driver's view of CUDA before starting workers (main_task context) - import os, ray - print(f"\n[Trainer.init_workers @ {os.uname()[1]}] Running in PID: {os.getpid()}") - - # Check CUDA availability but use a CPU fallback if needed - cuda_device = get_safe_device(allow_cpu_fallback=True) # This will print warnings if CUDA is not available - - print(f"[Trainer.init_workers] Using primary device: {cuda_device}") - - # Get available resources from Ray - ray_resources = ray.available_resources() - print(f"[Trainer.init_workers] Ray available resources: {ray_resources}") - ray_gpus = ray_resources.get('GPU', 0) - print(f"[Trainer.init_workers] Ray has {ray_gpus} GPUs available for allocation") - - # Configure resource pools - total_gpus_needed = 1 # Default minimum - if hasattr(self.config, 'trainer') and hasattr(self.config.trainer, 'n_gpus_per_node'): - total_gpus_needed = self.config.trainer.n_gpus_per_node - - print(f"[Trainer.init_workers] Configuring resource pools with {total_gpus_needed} GPUs per node") - - # Create the resource pool with the specified number of GPUs per node - try: - self.resource_pool_manager.create_resource_pool() - print(f"[Trainer.init_workers] Resource pools created: {list(self.resource_pool_manager.resource_pool_dict.keys())}") - except Exception as e: - print(f"[Trainer.init_workers] Error creating resource pools: {e}") - import traceback - traceback.print_exc() - raise RuntimeError(f"Failed to create resource pools: {e}") + """Init resource pool and worker group""" + self.resource_pool_manager.create_resource_pool() self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - # --- Map Roles to Classes and Resource Pools --- - # create actor and rollout - WITHOUT specifying ray_options + # create actor and rollout if self.hybrid_engine: resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - - # Create without ray_options - use original approach - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role='actor_rollout' - ) + actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], + 
config=self.config.actor_rollout_ref, + role='actor_rollout') self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls - print(f"[Trainer.init_workers] ActorRollout mapped to pool '{resource_pool.name_prefix}'") else: raise NotImplementedError - # create critic - WITHOUT specifying ray_options + # create critic if self.config.algorithm.adv_estimator == 'gae': resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - - critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], - config=self.config.critic - ) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls self.use_critic = True - print(f"[Trainer.init_workers] Critic mapped to pool '{resource_pool.name_prefix}'") + elif self.config.algorithm.adv_estimator == 'grpo': self.use_critic = False - # <<< Add log here >>> - print(f"[Trainer.init_workers] adv_estimator is '{self.config.algorithm.adv_estimator}', setting self.use_critic = False") else: raise NotImplementedError - # create reference policy if needed - WITHOUT specifying ray_options + # create reference policy if needed if self.use_reference_policy: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - - ref_policy_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role='ref' - ) + ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role='ref') self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls - print(f"[Trainer.init_workers] RefPolicy mapped to pool '{resource_pool.name_prefix}'") - # create a reward model if reward_fn is None - WITHOUT specifying ray_options + # create a reward model if reward_fn is None if self.use_rm: + # we create a RM here resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - - rm_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.RewardModel], - config=self.config.reward_model - ) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls - print(f"[Trainer.init_workers] RewardModel mapped to pool '{resource_pool.name_prefix}'") - # ... rest of the method remains unchanged # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
all_wg = {} self.wg_dicts = [] - print("\n[Trainer.init_workers] Initializing Worker Groups...") for resource_pool, class_dict in self.resource_pool_to_cls.items(): - print(f" Initializing group for resource pool: {resource_pool.name_prefix}") worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - # Pass resource requests (like num_gpus) defined in RayClassWithInitArgs to the group wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - print(f" Spawning workers for group {resource_pool.name_prefix}...") - try: - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - self.wg_dicts.append(wg_dict) - print(f" Successfully spawned workers: {list(spawn_wg.keys())}") - - # --- Log assigned resources --- - # Note: Getting precise GPU IDs assigned by Ray to specific actors - # after spawn can be tricky from the outside. - # We'll rely on checks *inside* the worker for now. - # Logging the group's overall placement gives some clue. - if hasattr(wg_dict, 'get_placement_group') and callable(getattr(wg_dict, 'get_placement_group')): - pg = wg_dict.get_placement_group() - if pg: - print(f" Group {resource_pool.name_prefix} placement group details: {pg.bundle_specs}") - else: - print(f" Group {resource_pool.name_prefix} does not have a placement group.") - else: - print(f" Cannot get placement group details for group {resource_pool.name_prefix}.") - - except Exception as e: - print(f"[ERROR] Failed to spawn workers for group {resource_pool.name_prefix}: {e}") - import traceback - traceback.print_exc() - raise # Re-raise the exception to stop execution + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 + self.wg_dicts.append(wg_dict) - # --- Assign worker groups --- - # Use .get for safety in case spawning failed for a group if self.use_critic: - self.critic_wg = all_wg.get('critic') - if self.critic_wg: - print("[Trainer.init_workers] Initializing Critic model...") - # TODO: Modify init_model call to pass assigned GPU IDs if known - self.critic_wg.init_model() - else: - print("[Trainer.init_workers][ERROR] Critic worker group not found after spawn.") - # Decide how to handle this - maybe raise an error? 
+ self.critic_wg = all_wg['critic'] + self.critic_wg.init_model() if self.use_reference_policy: - self.ref_policy_wg = all_wg.get('ref') - if self.ref_policy_wg: - print("[Trainer.init_workers] Initializing RefPolicy model...") - # TODO: Modify init_model call - self.ref_policy_wg.init_model() - else: - print("[Trainer.init_workers][ERROR] RefPolicy worker group not found after spawn.") + self.ref_policy_wg = all_wg['ref'] + self.ref_policy_wg.init_model() if self.use_rm: - self.rm_wg = all_wg.get('rm') - if self.rm_wg: - print("[Trainer.init_workers] Initializing RewardModel model...") - # TODO: Modify init_model call - self.rm_wg.init_model() - else: - print("[Trainer.init_workers][ERROR] RewardModel worker group not found after spawn.") - - # Initialize actor_rollout last - self.actor_rollout_wg = all_wg.get('actor_rollout') - if self.actor_rollout_wg: - print("[Trainer.init_workers] Initializing ActorRollout model...") - # TODO: Modify init_model call - self.actor_rollout_wg.init_model() - else: - print("[Trainer.init_workers][ERROR] ActorRollout worker group not found after spawn.") - - # --- Verify CUDA setup for each initialized worker group --- - print("\n[Trainer.init_workers] Verifying CUDA setup for initialized workers...") - if self.actor_rollout_wg: - self.verify_worker_cuda_setup("actor_rollout", self.actor_rollout_wg) - if self.use_critic and self.critic_wg: - self.verify_worker_cuda_setup("critic", self.critic_wg) - if self.use_reference_policy and self.ref_policy_wg: - self.verify_worker_cuda_setup("ref_policy", self.ref_policy_wg) - if self.use_rm and self.rm_wg: - self.verify_worker_cuda_setup("reward_model", self.rm_wg) - - print("[Trainer.init_workers] Worker initialization and verification complete.") + self.rm_wg = all_wg['rm'] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg['actor_rollout'] + self.actor_rollout_wg.init_model() def _save_checkpoint(self): actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', @@ -1219,17 +925,16 @@ def fit(self): # Define log_dir here based on config self.log_dir = self.config.trainer.get("default_local_dir", "./verl_checkpoints/default_log_dir") os.makedirs(self.log_dir, exist_ok=True) - print(f"[Trainer.fit] Log directory set to: {self.log_dir}") + print(f"[Trainer.fit][DEBUG] Log directory set to: {self.log_dir}") # DEBUG # Determine if this is an AgentGym run upfront self.is_agentgym_run = self.config.data.env_name in KNOWN_AGENTGYM_ENVS - print(f"[Trainer.fit] Is AgentGym run: {self.is_agentgym_run}") + print(f"[Trainer.fit][DEBUG] Is AgentGym run: {self.is_agentgym_run}") # DEBUG # Get advantage estimator strategy adv_estimator = self.config.algorithm.adv_estimator - print(f"[Trainer.fit] Using advantage estimator: {adv_estimator}") - # <<< Add log here >>> - print(f"[Trainer.fit] Value of self.use_critic at start of loop: {self.use_critic}") + print(f"[Trainer.fit][DEBUG] Using advantage estimator: {adv_estimator}") # DEBUG + print(f"[Trainer.fit][DEBUG] Value of self.use_critic at start: {self.use_critic}") # DEBUG # 如果使用GRPO但仍然设置了use_critic为True,发出警告 if adv_estimator == 'grpo' and self.use_critic: @@ -1241,7 +946,7 @@ def fit(self): # Agent config preparation (Only needed if AgentGym run) generation_manager = None if self.is_agentgym_run: - print(f"[Trainer.fit] Initializing OpenManusAgent for AgentGym environment: {self.config.data.env_name}") + print(f"[Trainer.fit][DEBUG] Initializing 
OpenManusAgent for AgentGym environment: {self.config.data.env_name}") # DEBUG try: gen_config = AgentConfig( max_turns=self.config.max_turns, @@ -1257,10 +962,10 @@ def fit(self): max_workers=self.config.actor_rollout_ref.rollout.get('max_workers', 10), algorithm_config=self.config.algorithm, ) - print(f"[Trainer.fit] AgentConfig initialized successfully") + print(f"[Trainer.fit][DEBUG] AgentConfig initialized successfully") # DEBUG agent_logger = self.logger if hasattr(self, 'logger') else None - print(f"[Trainer.fit] Creating OpenManusAgent...") + print(f"[Trainer.fit][DEBUG] Creating OpenManusAgent...") # DEBUG generation_manager = OpenManusAgent( tokenizer=self.tokenizer, actor_rollout_wg=self.actor_rollout_wg, @@ -1268,7 +973,7 @@ def fit(self): is_validation = False, logger=agent_logger ) - print(f"[Trainer.fit] OpenManusAgent created successfully") + print(f"[Trainer.fit][DEBUG] OpenManusAgent created successfully") # DEBUG except Exception as e: print(f"[Trainer.fit][ERROR] Failed to initialize OpenManusAgent: {e}") import traceback @@ -1276,20 +981,29 @@ def fit(self): raise # start training loop - print(f"[Trainer.fit] Starting training loop for {self.config.trainer.total_epochs} epochs") + print(f"[Trainer.fit][DEBUG] Starting training loop for {self.config.trainer.total_epochs} epochs") for epoch in range(self.config.trainer.total_epochs): - print(f"[Trainer.fit] Starting epoch {epoch}") + print(f"[Trainer.fit][DEBUG] Starting Epoch {epoch}") # DEBUG for batch_idx, batch_dict in enumerate(self.train_dataloader): - print(f"[Trainer.fit][STEP] === Epoch {epoch}, Step {self.global_steps}, Batch {batch_idx} ===") + print(f"[Trainer.fit][DEBUG] === Starting Epoch {epoch}, Step {self.global_steps}, Batch {batch_idx} ===") # DEBUG metrics = {} timing_raw = {} - print(f"[Trainer.fit][STEP {self.global_steps}] Creating DataProto from batch dictionary") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Creating DataProto from batch dictionary.") # DEBUG batch: DataProto = DataProto.from_single_dict(batch_dict) original_batch_size = batch.batch['input_ids'].shape[0] - print(f"[Trainer.fit][STEP {self.global_steps}] Original batch size: {original_batch_size}") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Original batch size: {original_batch_size}") # DEBUG + + # --- Debug Print: Initial Batch State --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Initial Batch Info:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + print(f" Meta Info Keys: {list(batch.meta_info.keys()) if hasattr(batch, 'meta_info') else 'N/A'}") + print(f" Non-Tensor Batch Keys: {list(batch.non_tensor_batch.keys()) if hasattr(batch, 'non_tensor_batch') else 'N/A'}") + # --- End Debug Print --- # Keep necessary keys for agent/rollout in gen_batch + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Popping keys for generation batch.") # DEBUG gen_batch = batch.pop(batch_keys=[ 'input_ids', 'attention_mask', 'position_ids' ]) @@ -1302,383 +1016,371 @@ def fit(self): #################### # Rollout / Generation Step #################### - print(f"[Trainer.fit][STEP {self.global_steps}] Starting rollout/generation step") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Starting rollout/generation step.") # DEBUG final_gen_batch_output = None with _timer('step', timing_raw): if self.is_agentgym_run: - # --- AgentGym Path --- - print(f"[Trainer.fit][STEP {self.global_steps}] Using AgentGym path") + # --- AgentGym Path --- 
+ print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Using AgentGym path.") # DEBUG with _timer('gen', timing_raw): - # Prepare output directory if logging images during training (less common) output_dir = os.path.join( - self.log_dir, # Use the defined log_dir + self.log_dir, f"train_step_{self.global_steps}" ) - print(f"[Trainer.fit][STEP {self.global_steps}] Calling generation_manager.run_llm_loop...") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Calling generation_manager.run_llm_loop...") # DEBUG + # --- Debug Print: Input to run_llm_loop --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Input gen_batch to run_llm_loop:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + print(f" Meta Info Keys: {list(gen_batch.meta_info.keys()) if hasattr(gen_batch, 'meta_info') else 'N/A'}") + # --- End Debug Print --- try: final_gen_batch_output = generation_manager.run_llm_loop( gen_batch=gen_batch, output_dir=output_dir, global_steps=self.global_steps ) - print(f"[Trainer.fit][STEP {self.global_steps}] Returned from generation_manager.run_llm_loop") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Returned from generation_manager.run_llm_loop.") # DEBUG + # --- Debug Print: Output from run_llm_loop --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Output final_gen_batch_output from run_llm_loop:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + print(f" Meta Info Keys: {list(final_gen_batch_output.meta_info.keys()) if hasattr(final_gen_batch_output, 'meta_info') else 'N/A'}") + # --- End Debug Print --- except Exception as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Encountered error in run_llm_loop: {e}") + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: Encountered error in run_llm_loop: {e}") # ERROR import traceback traceback.print_exc() continue # Skip to next batch if rollout failed if not final_gen_batch_output or final_gen_batch_output.batch is None or final_gen_batch_output.batch.is_empty(): - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] AgentGym rollout returned empty batch. Skipping step.") - # Instead of continue, raise an error to halt if this is unexpected + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: AgentGym rollout returned empty batch. 
Skipping step.") # ERROR raise RuntimeError(f"AgentGym rollout returned empty batch at step {self.global_steps}") # Add log probs (needed for PPO loss calculation later) - print(f"[Trainer.fit][STEP {self.global_steps}] Computing log probabilities") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Computing log probabilities.") # DEBUG with torch.no_grad(), _timer('logp', timing_raw): if 'input_ids' in final_gen_batch_output.batch: actor_rollout_world_size = self.actor_rollout_wg.world_size - print(f"[Trainer.fit][STEP {self.global_steps}] ActorRollout world size for padding: {actor_rollout_world_size}") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: ActorRollout world size for padding: {actor_rollout_world_size}") # DEBUG padded_batch_for_logp, pad_size_logp = pad_dataproto_to_divisor( final_gen_batch_output, actor_rollout_world_size ) if pad_size_logp > 0: - print(f"[Trainer.fit][STEP {self.global_steps}] Padded batch for compute_log_prob by {pad_size_logp} samples.") - + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Padded batch for compute_log_prob by {pad_size_logp} samples.") # DEBUG + + # --- Populate meta_info for actor_rollout_wg.compute_log_prob --- + # These parameters are expected by DataParallelActor.compute_log_prob on the worker. + # Source them from self.config.actor_rollout_ref.rollout logp_mbs = self.config.actor_rollout_ref.rollout.log_prob_micro_batch_size - padded_batch_for_logp.meta_info['micro_batch_size'] = logp_mbs + use_dyn_bsz = self.config.actor_rollout_ref.rollout.get('log_prob_use_dynamic_bsz', False) temperature = self.config.actor_rollout_ref.rollout.temperature - padded_batch_for_logp.meta_info['temperature'] = temperature - use_dyn_bsz = self.config.actor_rollout_ref.rollout.log_prob_use_dynamic_bsz + print(f"[DEBUG][Trainer.fit] Step {self.global_steps}: Sourced config for actor.compute_log_prob: log_prob_micro_batch_size={logp_mbs}, log_prob_use_dynamic_bsz={use_dyn_bsz}, temperature={temperature}") # DEBUG + padded_batch_for_logp.meta_info['micro_batch_size'] = logp_mbs padded_batch_for_logp.meta_info['use_dynamic_bsz'] = use_dyn_bsz + padded_batch_for_logp.meta_info['temperature'] = temperature + if use_dyn_bsz: + max_token_len_logp = self.config.actor_rollout_ref.rollout.get( + 'log_prob_max_token_len_per_gpu', + self.config.data.max_prompt_length + ) + padded_batch_for_logp.meta_info['max_token_len'] = max_token_len_logp + print(f"[DEBUG][Trainer.fit] Step {self.global_steps}: For dynamic log_prob batching, set max_token_len={max_token_len_logp}") # DEBUG + else: + padded_batch_for_logp.meta_info.pop('max_token_len', None) + print(f"[DEBUG][Trainer.fit] Step {self.global_steps}: Final padded_batch_for_logp.meta_info for compute_log_prob: {padded_batch_for_logp.meta_info}") # DEBUG + # --- End of meta_info population --- - print(f"[Trainer.fit][STEP {self.global_steps}] Calling actor_rollout_wg.compute_log_prob with (potentially) padded batch...") + # --- Debug Print: Input to compute_log_prob --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Input padded_batch_for_logp to compute_log_prob:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + print(f" Meta Info: {padded_batch_for_logp.meta_info if hasattr(padded_batch_for_logp, 'meta_info') else 'N/A'}") + # --- End Debug Print --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Calling actor_rollout_wg.compute_log_prob...") # DEBUG try: output_logp_padded = 
self.actor_rollout_wg.compute_log_prob(padded_batch_for_logp) output_logp = unpad_dataproto(output_logp_padded, pad_size=pad_size_logp) if pad_size_logp > 0: - print(f"[Trainer.fit][STEP {self.global_steps}] Unpadded log_prob output.") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Unpadded log_prob output.") # DEBUG final_gen_batch_output = final_gen_batch_output.union(output_logp) - print(f"[Trainer.fit][STEP {self.global_steps}] Log probabilities computed successfully") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Log probabilities computed successfully.") # DEBUG except Exception as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Error computing log probabilities: {e}") + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: Error computing log probabilities: {e}") # ERROR traceback.print_exc() raise # Re-raise to halt execution else: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Cannot compute log probabilities, 'input_ids' not found in batch") + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: Cannot compute log probabilities, 'input_ids' not found in batch.") # ERROR raise RuntimeError("Cannot compute log_prob, input_ids missing") # Halt execution - batch = final_gen_batch_output - print(f"[Trainer.fit][STEP {self.global_steps}] Setting up token_level_scores") + batch = final_gen_batch_output # Update batch with the results + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Setting up token_level_scores.") # DEBUG if 'token_level_rewards' in batch.batch: batch.batch['token_level_scores'] = batch.batch['token_level_rewards'].clone() - print(f"[Trainer.fit][STEP {self.global_steps}] Cloned token_level_rewards to token_level_scores") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Cloned token_level_rewards to token_level_scores.") # DEBUG else: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] 'token_level_rewards' not found in batch. Creating zero scores.") + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: 'token_level_rewards' not found in batch after run_llm_loop. 
Creating zero scores.") # ERROR if 'input_ids' in batch.batch: batch.batch['token_level_scores'] = torch.zeros_like(batch.batch['input_ids'], dtype=torch.float) else: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Cannot create zero 'token_level_scores' because 'input_ids' is missing.") + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: Cannot create zero 'token_level_scores' because 'input_ids' is missing.") # ERROR continue - # --- FIX: Convert UID list to NumPy array with dtype=object --- - print(f"[Trainer.fit][STEP {self.global_steps}] Setting up UID for batch") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Setting up UID for batch.") # DEBUG if 'idx' in batch.meta_info: - # Ensure idx tensor is moved to CPU before converting to list uid_list = batch.meta_info['idx'].cpu().tolist() - batch.non_tensor_batch['uid'] = np.array(uid_list, dtype=object) # Explicitly set dtype=object - print(f"[Trainer.fit][STEP {self.global_steps}] Used existing idx as UID") - else: # Fallback UID + batch.non_tensor_batch['uid'] = np.array(uid_list, dtype=object) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Used existing idx as UID.") # DEBUG + else: uid_list = [str(uuid.uuid4()) for _ in range(batch.batch['input_ids'].shape[0])] - batch.non_tensor_batch['uid'] = np.array(uid_list, dtype=object) # Explicitly set dtype=object - print(f"[Trainer.fit][STEP {self.global_steps}] Created new UUIDs as UID") - # --- END FIX --- + batch.non_tensor_batch['uid'] = np.array(uid_list, dtype=object) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Created new UUIDs as UID.") # DEBUG else: - # --- Original Path (Non-AgentGym) --- - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Non-AgentGym training path not implemented. Skipping.") + # --- Original Path (Non-AgentGym) --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Using Non-AgentGym generation path.") # DEBUG + # Add debug logs for non-agentgym path if needed + print(f"[Trainer.fit][ERROR] Step {self.global_steps}: Non-AgentGym training path not fully implemented with debug logs. Skipping.") # ERROR continue # Skip processing for now # Apply batch repetition if configured (AFTER generation/rollout) if self.config.actor_rollout_ref.rollout.n > 1: - print(f"[Trainer.fit][STEP {self.global_steps}] Repeating batch {self.config.actor_rollout_ref.rollout.n} times") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Repeating batch {self.config.actor_rollout_ref.rollout.n} times.") # DEBUG batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) #################### # Post-Rollout Processing (Common for both paths after merging) #################### - print(f"[Trainer.fit][STEP {self.global_steps}] Balancing batch") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Starting post-rollout processing.") # DEBUG + + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Balancing batch...") # DEBUG self._balance_batch(batch, metrics=metrics) - print(f"[Trainer.fit][STEP {self.global_steps}] Batch balanced successfully") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Batch balanced successfully.") # DEBUG - # --- COMPLETELY RESTRUCTURED COMPUTATION FLOW --- - # Follow verl implementation pattern: First compute critic values, then compute advantages once + # compute global_valid tokens (Maybe move after all data is ready?) 
+ if 'attention_mask' in batch.batch: + batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + else: + print(f"[Trainer.fit][WARN] Step {self.global_steps}: 'attention_mask' not in batch for global_token_num calculation.") + + # --- Compute Reference Log Probs --- + if self.use_reference_policy: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Computing reference log probs...") # DEBUG + with _timer('ref', timing_raw): + # --- Debug Print: Input to compute_ref_log_prob --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Input batch to compute_ref_log_prob:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + # --- End Debug Print --- + ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob_output) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Reference log probs computed.") # DEBUG + else: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Skipping reference log prob computation.") # DEBUG - # --- 1. Compute Critic Values (if needed for GAE) --- + # --- Compute Critic Values --- if self.use_critic and adv_estimator == 'gae': - print(f"[DEBUG] ****** COMPUTING CRITIC VALUES (Step: {self.global_steps}) ******") - print(f"[Trainer.fit][STEP {self.global_steps}] Computing critic values for GAE") - print(f"[DEBUG] Before values computation, batch keys: {list(batch.batch.keys())}") - + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Computing critic values for GAE...") # DEBUG with _timer('compute_values', timing_raw): - try: - # REMOVED: Logic to get worker device and move tensors to CUDA - # TaskRunner should not perform device placement. - - # Check current device for logging purposes - ref_tensor = None - current_device = 'cpu' # Default assumption - for key in ['input_ids', 'attention_mask', 'position_ids']: - if key in batch.batch: - ref_tensor = batch.batch[key] - current_device = ref_tensor.device - break - - if ref_tensor is not None: - print(f"[DEBUG] Current batch tensor device: {current_device}") - - # Call critic worker to compute values - pass tensors as they are - print(f"[DEBUG] Sending batch to critic_wg.compute_values (tensors on {current_device})...") - values_output = self.critic_wg.compute_values(batch) - - # Check if values were returned correctly - if 'values' in values_output.batch: - values_tensor = values_output.batch['values'] - print(f"[DEBUG] Values computed successfully: shape={values_tensor.shape}, device={values_tensor.device}") - - # Directly assign values to batch (avoiding union operation) - batch.batch['values'] = values_tensor.clone() # Use clone for safety - - # Create a backup copy for safety - self._values_backup = values_tensor.clone() - print(f"[DEBUG] Values assigned to batch and backup created") - print(f"[DEBUG] After values assignment, batch keys: {list(batch.batch.keys())}") - else: - raise ValueError("CriticWorker.compute_values did not return required 'values' field") - except Exception as e: - print(f"[ERROR] Failed to compute critic values: {e}") - import traceback - traceback.print_exc() - continue # Skip to next batch if values computation failed - - # --- 2. 
Compute Advantages (ONLY ONCE) --- - print(f"[DEBUG] ****** COMPUTING ADVANTAGES (Step: {self.global_steps}) ******") - print(f"[Trainer.fit][STEP {self.global_steps}] Computing advantages with estimator: {adv_estimator}") - print(f"[DEBUG] Before advantage computation, batch keys: {list(batch.batch.keys())}") - - # Safety check for GAE - ensure values are present - if self.use_critic and adv_estimator == 'gae' and 'values' not in batch.batch: - if hasattr(self, '_values_backup'): - print(f"[WARNING] Values key missing before advantage computation - restoring from backup") - batch.batch['values'] = self._values_backup.clone() - else: - print(f"[ERROR] Values required for GAE but missing from batch and no backup available") - continue # Skip this batch - - # Get device for compute_advantage computation - # (ideally should match the device of the batch tensors) - target_device = 'cuda' if torch.cuda.is_available() else 'cpu' - - # Check if all tensors are on the same device - device_check = {} - for key, tensor in batch.batch.items(): - if isinstance(tensor, torch.Tensor): - device_check[key] = tensor.device - - if len(set(str(dev) for dev in device_check.values())) > 1: - print(f"[WARNING] Detected tensors on different devices: {device_check}") - print(f"[DEBUG] Moving all tensors to {target_device} for consistent computation") - - # Move all tensors to the target device - for key, tensor in batch.batch.items(): - if isinstance(tensor, torch.Tensor) and str(tensor.device) != str(target_device): - batch.batch[key] = tensor.to(target_device) - - # Log key tensor devices for debugging - if 'values' in batch.batch: - print(f"[DEBUG] Device for values: {batch.batch['values'].device}") - if 'token_level_rewards' in batch.batch: - print(f"[DEBUG] Device for token_level_rewards: {batch.batch['token_level_rewards'].device}") - - with _timer('adv', timing_raw): - try: - # SINGLE advantage computation - batch = compute_advantage( - data=batch, - adv_estimator=adv_estimator, - gamma=self.config.algorithm.get('gamma', 1.0), - lam=self.config.algorithm.get('lambda', 1.0) - ) - print(f"[DEBUG] Advantages computed successfully") - print(f"[DEBUG] After advantage computation, batch keys: {list(batch.batch.keys())}") + # Read config values for critic + critic_mbs = self.config.critic.ppo_micro_batch_size + critic_use_dyn_bsz = self.config.critic.get('use_dynamic_bsz', False) + print(f"[CRITICAL DEBUG][Trainer.fit] Step {self.global_steps}: Value read from self.config.critic.ppo_micro_batch_size = {critic_mbs}") + print(f"[CRITICAL DEBUG][Trainer.fit] Step {self.global_steps}: Value read from self.config.critic.get('use_dynamic_bsz') = {critic_use_dyn_bsz}") + + # --- Create a temporary, isolated DataProto for the compute_values call --- + # 1. 
Create a dedicated meta_info dictionary + critic_meta_info = {} + critic_meta_info['micro_batch_size'] = critic_mbs + critic_meta_info['use_dynamic_bsz'] = critic_use_dyn_bsz + if critic_use_dyn_bsz: + critic_meta_info['max_token_len'] = self.config.critic.get('ppo_max_token_len_per_gpu', 2048) + # No need to pop max_token_len if false, it just won't be added + print(f"[DEBUG][Trainer.fit] Step {self.global_steps}: Prepared TEMPORARY critic_meta_info for compute_values: {critic_meta_info}") # DEBUG - # Check device of computed advantages - if 'advantages' in batch.batch: - print(f"[DEBUG] Device for advantages: {batch.batch['advantages'].device}") - if 'returns' in batch.batch: - print(f"[DEBUG] Device for returns: {batch.batch['returns'].device}") - except Exception as e: - print(f"[ERROR] Failed to compute advantages: {e}") - import traceback - traceback.print_exc() - continue # Skip to next batch if advantage computation failed - - # --- KL Penalty (if using reference policy) --- - if self.use_reference_policy and 'ref_log_prob' in batch.batch and 'old_log_probs' in batch.batch: - print(f"[Trainer.fit][STEP {self.global_steps}] Applying KL penalty") - with _timer('kl_penalty', timing_raw): - try: - batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, kl_penalty=self.config.algorithm.get('kl_penalty', 'kl')) - metrics.update(kl_metrics) - print(f"[Trainer.fit][STEP {self.global_steps}] KL penalty applied successfully") - except Exception as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Error applying KL penalty: {e}") - import traceback - traceback.print_exc() - # Continue anyway, this isn't critical - Keeping this as continue, as KL might not be essential - - # --- Compute Critic Values --- - if self.use_critic: - print(f"[Trainer.fit][STEP {self.global_steps}] Updating critic model") - if 'advantages' not in batch.batch or 'returns' not in batch.batch: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Missing 'advantages' or 'returns' in batch, required for critic update. Skipping critic update.") - continue # We change this from a warning to error and skip the batch - else: - with _timer('update_critic', timing_raw): - print(f"[Trainer.fit][STEP {self.global_steps}] Calling critic_wg.update_critic...") - try: - # REMOVED: Explicit device checking and moving logic before calling worker - # The worker itself should handle device placement. - - # Log tensor devices for debugging purposes before sending - adv_device = batch.batch['advantages'].device - returns_device = batch.batch['returns'].device - print(f"[DEBUG] Pre-critic update tensor devices (in TaskRunner): advantages={adv_device}, returns={returns_device}") - - # Call update_critic - critic_output = self.critic_wg.update_critic(batch) - - # Process results (assuming they are returned to CPU or handled correctly) - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) - metrics.update(critic_output_metrics) - print(f"[Trainer.fit][STEP {self.global_steps}] Critic model updated successfully") - except Exception as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Error updating critic: {e}") - import traceback - traceback.print_exc() - raise # Re-raise to halt execution + # 2. 
Select only the necessary tensors from the original batch + critic_input_tensors = { + key: batch.batch[key] + for key in ['responses', 'input_ids', 'attention_mask', 'position_ids'] + if key in batch.batch + } + if len(critic_input_tensors) != 4: + print(f"[WARN][Trainer.fit] Step {self.global_steps}: Missing some required keys for critic compute_values in batch.batch. Found: {list(critic_input_tensors.keys())}") + + # 3. Create the temporary DataProto object + critic_input_proto = DataProto.from_dict(critic_input_tensors) + critic_input_proto.meta_info = critic_meta_info # Assign the dedicated meta_info + + # --- Debug print the temporary object --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: TEMPORARY Input critic_input_proto to compute_values:") + print(f" Batch Keys & Shapes & Devices:") + print(f" Meta Info: {critic_input_proto.meta_info}") + # --- End Debug print --- + + # 4. Call the worker with the temporary object + values_output = self.critic_wg.compute_values(critic_input_proto) + # --- End modification for temporary object --- + + # --- Debug Print: Output from compute_values --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Output from compute_values:") + # --- End Debug Print --- + batch = batch.union(values_output) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Critic values computed and merged.") # DEBUG + elif self.use_critic and adv_estimator != 'gae': + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Critic exists but not used for GAE, skipping compute_values for advantage calculation.") # DEBUG else: - print(f"[Trainer.fit][STEP {self.global_steps}] Skipping critic update (not enabled for {adv_estimator}) or missing required data") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Critic not used, skipping compute_values.") # DEBUG - # --- Update Actor --- - print(f"[Trainer.fit][STEP {self.global_steps}] Updating actor model") - if self.config.trainer.critic_warmup <= self.global_steps: - if 'advantages' not in batch.batch or 'old_log_probs' not in batch.batch: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Missing 'advantages' or 'old_log_probs' in batch, required for actor update. 
Skipping actor update.") - # Instead of continue, raise an error - raise RuntimeError("Missing required data for actor update") + # --- Apply Reward Function / KL Penalty --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Computing scores and rewards...") # DEBUG + with _timer('adv', timing_raw): + # Compute scores (potentially using RM) + if self.use_rm: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Computing RM score...") # DEBUG + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: RM score computed and merged.") # DEBUG + + # Combine with rule-based/external reward function if provided + if self.reward_fn is not None: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Applying external reward_fn...") # DEBUG + reward_tensor = self.reward_fn(batch) # Assuming reward_fn returns the tensor directly + batch.batch['token_level_scores'] = reward_tensor + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: External reward_fn applied.") # DEBUG + elif 'reward_model_scores' in batch.batch: # If only RM was used + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Using RM scores as token_level_scores.") # DEBUG + batch.batch['token_level_scores'] = batch.batch['reward_model_scores'] + elif 'token_level_scores' not in batch.batch: + print(f"[Trainer.fit][WARN] Step {self.global_steps}: 'token_level_scores' not found after RM/reward_fn. Ensure one is active or rewards are set elsewhere.") # WARN + # If scores are set directly by agentgym run_llm_loop, this might be okay. + + # Apply KL penalty (modifies token_level_scores -> token_level_rewards) + if self.use_reference_policy and not self.config.actor_rollout_ref.actor.get('use_kl_loss', False): + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Applying KL penalty...") # DEBUG + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.get('kl_penalty', 'kl')) # Use .get + metrics.update(kl_metrics) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: KL penalty applied.") # DEBUG else: - with _timer('update_actor', timing_raw): - print(f"[Trainer.fit][STEP {self.global_steps}] Calling actor_rollout_wg.update_actor...") - try: - if self.is_agentgym_run and hasattr(self.config.actor_rollout_ref.actor, 'state_masking') and self.config.actor_rollout_ref.actor.state_masking: - print(f"[Trainer.fit][STEP {self.global_steps}] State masking is enabled, creating loss_mask") - batch, actor_metrics = self._create_loss_mask(batch, metrics) - metrics.update(actor_metrics) - else: - print(f"[Trainer.fit][STEP {self.global_steps}] State masking is not enabled, creating default loss_mask") - response_length = batch.batch['responses'].shape[-1] - batch.batch['loss_mask'] = torch.ones_like(batch.batch['attention_mask'][:, -response_length:]) - - loss_mask_device = batch.batch['loss_mask'].device - adv_device = batch.batch['advantages'].device - old_log_probs_device = batch.batch['old_log_probs'].device - print(f"[DEBUG] Pre-actor update tensor devices (in TaskRunner): loss_mask={loss_mask_device}, advantages={adv_device}, old_log_probs={old_log_probs_device}") - - actor_output = self.actor_rollout_wg.update_actor(batch) - - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - print(f"[Trainer.fit][STEP {self.global_steps}] Actor model updated successfully") - except Exception as e: - print(f"[Trainer.fit][STEP 
{self.global_steps}][ERROR] Error updating actor: {e}") - import traceback - traceback.print_exc() - raise # Re-raise to halt execution + if 'token_level_rewards' not in batch.batch and 'token_level_scores' in batch.batch: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Setting token_level_rewards = token_level_scores (no KL penalty/loss).") # DEBUG + batch.batch['token_level_rewards'] = batch.batch['token_level_scores'].clone() + elif 'token_level_rewards' not in batch.batch: + print(f"[Trainer.fit][WARN] Step {self.global_steps}: 'token_level_rewards' not set and KL penalty not applied.") # WARN + + # --- Compute Advantages --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Computing advantages (estimator: {adv_estimator})...") # DEBUG + # --- Debug Print: Input to compute_advantage --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Input batch to compute_advantage:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + batch = compute_advantage(batch, + adv_estimator=adv_estimator, + gamma=self.config.algorithm.get('gamma', 1.0), + lam=self.config.algorithm.get('lambda', 1.0), + num_repeat=self.config.actor_rollout_ref.rollout.get('n', 1)) # Use .get + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Advantages computed.") # DEBUG + + # --- Update Critic --- + if self.use_critic: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Updating critic...") # DEBUG + with _timer('update_critic', timing_raw): + # --- Debug Print: Input to update_critic --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Input batch to update_critic:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + # --- End Debug Print --- + critic_output = self.critic_wg.update_critic(batch) # Returns DataProto with metrics + if hasattr(critic_output, 'meta_info') and 'metrics' in critic_output.meta_info: + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + else: + print(f"[Trainer.fit][WARN] Step {self.global_steps}: Critic update did not return metrics in meta_info.") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Critic updated.") # DEBUG else: - print(f"[Trainer.fit][STEP {self.global_steps}] Skipping actor update (in critic warmup phase)") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Skipping critic update.") # DEBUG - # --- Save Checkpoint --- - if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: - print(f"[Trainer.fit][STEP {self.global_steps}] Saving checkpoint") - with _timer('save_checkpoint', timing_raw): - try: - self._save_checkpoint() - print(f"[Trainer.fit][STEP {self.global_steps}] Checkpoint saved successfully") - except Exception as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Error saving checkpoint: {e}") - import traceback - traceback.print_exc() - # Saving checkpoint might fail due to FS issues, allow continuation but log error - # raise # Optional: Uncomment to halt on checkpoint save failure - - # --- Collect and Log Metrics --- - print(f"[Trainer.fit][STEP {self.global_steps}] Collecting and logging metrics") - try: - # Check for necessary keys before computing metrics - required_keys = ['token_level_scores', 'token_level_rewards', 'advantages', 'returns', 'responses', 'attention_mask'] - if self.use_critic: required_keys.append('values') - # Add meta_info keys needed for env metrics 
- required_meta_keys = ['turns_stats', 'active_mask', 'valid_action_stats', 'valid_search_stats'] - - # Ensure all required meta keys exist (add defaults if missing) - for meta_key in required_meta_keys: - if meta_key not in batch.meta_info: - if meta_key == 'active_mask': - batch.meta_info[meta_key] = np.ones(batch.batch['input_ids'].shape[0], dtype=np.int16) + # --- Update Actor --- + if self.config.trainer.get('critic_warmup', 0) <= self.global_steps: # Use .get + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Updating actor...") # DEBUG + with _timer('update_actor', timing_raw): + # state masking is only applicable for search agent + if self.is_agentgym_run and hasattr(self.config.actor_rollout_ref.actor, 'state_masking') and self.config.actor_rollout_ref.actor.state_masking: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Applying state masking...") # DEBUG + batch, actor_metrics = self._create_loss_mask(batch, metrics) + metrics.update(actor_metrics) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: State masking applied.") # DEBUG + elif 'loss_mask' not in batch.batch: + # Ensure loss_mask exists if not using state masking (defaults to response mask) + if 'responses' in batch.batch and 'attention_mask' in batch.batch: + response_length = batch.batch['responses'].shape[-1] + batch.batch['loss_mask'] = batch.batch['attention_mask'][:, -response_length:].clone() + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Created default loss_mask (response mask).") # DEBUG + else: + print(f"[Trainer.fit][WARN] Step {self.global_steps}: Cannot create default loss_mask, missing 'responses' or 'attention_mask'.") # WARN + + # --- Debug Print: Input to update_actor --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Input batch to update_actor:") + print(f" Batch Keys & Shapes & Devices:") + # Check if batch attribute exists, is not None, and is not empty + # --- End Debug Print --- + actor_output = self.actor_rollout_wg.update_actor(batch) # Returns DataProto with metrics + if hasattr(actor_output, 'meta_info') and 'metrics' in actor_output.meta_info: + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) else: - batch.meta_info[meta_key] = np.zeros(batch.batch['input_ids'].shape[0], dtype=np.int16) - print(f"[Trainer.fit][STEP {self.global_steps}] Added default value for missing meta key: {meta_key}") - - can_compute_metrics = all(key in batch.batch for key in required_keys) and all(key in batch.meta_info for key in required_meta_keys) - if can_compute_metrics: - print(f"[Trainer.fit][STEP {self.global_steps}] Computing all metrics") - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + print(f"[Trainer.fit][WARN] Step {self.global_steps}: Actor update did not return metrics in meta_info.") + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Actor updated.") # DEBUG else: - missing_keys = [k for k in required_keys if k not in batch.batch] - missing_meta = [k for k in required_meta_keys if k not in batch.meta_info] - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Cannot compute metrics due to missing keys: {missing_keys}, {missing_meta}") - # Log timing separately if main metrics can't be computed - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - except KeyError as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Metrics calculation failed due to 
KeyError: {e}") - except Exception as e: # Catch other potential errors during metric calculation - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Error during metric calculation: {e}") - import traceback - traceback.print_exc() - - # Log metrics - print(f"[Trainer.fit][STEP {self.global_steps}] Logging metrics to tracking system") - try: - logger.log(data=metrics, step=self.global_steps) - except Exception as e: - print(f"[Trainer.fit][STEP {self.global_steps}][ERROR] Error logging metrics: {e}") - - print(f"[Trainer.fit][STEP {self.global_steps}] Completed step {self.global_steps}") - self.global_steps += 1 + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Skipping actor update (critic warmup phase).") # DEBUG + + # --- Validate --- + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ + self.global_steps % self.config.trainer.test_freq == 0: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Starting validation...") # DEBUG + with _timer('testing', timing_raw): + val_metrics: dict = self._validate() + metrics.update(val_metrics) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Validation finished.") # DEBUG + + # --- Save Checkpoint --- + if self.config.trainer.save_freq > 0 and \ + self.global_steps % self.config.trainer.save_freq == 0: + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Saving checkpoint...") # DEBUG + with _timer('save_checkpoint', timing_raw): + self._save_checkpoint() + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Checkpoint saved.") # DEBUG - if self.global_steps >= self.total_training_steps: - print(f"[Trainer.fit] Reached total training steps ({self.total_training_steps}). Exiting.") - return + # --- Collect Metrics --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Collecting and computing metrics...") # DEBUG + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Metrics computed: {list(metrics.keys())}") # DEBUG - print(f"[Trainer.fit] Completed epoch {epoch}") - - print(f"[Trainer.fit] Training complete") + # --- Log Metrics --- + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Logging metrics...") # DEBUG + logger.log(data=metrics, step=self.global_steps) + print(f"[Trainer.fit][DEBUG] Step {self.global_steps}: Metrics logged.") # DEBUG + + self.global_steps += 1 + + if self.config.trainer.total_training_steps is not None and self.global_steps >= self.config.trainer.total_training_steps: + print(f"[Trainer.fit][DEBUG] Reached total training steps ({self.config.trainer.total_training_steps}). 
Exiting training loop.") # DEBUG + # perform validation after training + if self.val_reward_fn is not None: + print(f"[Trainer.fit][DEBUG] Performing final validation...") # DEBUG + val_metrics = self._validate() + pprint(f'Final validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=self.global_steps) + print(f"[Trainer.fit][DEBUG] Final validation logged.") # DEBUG + return + print(f"[Trainer.fit][DEBUG] Finished Epoch {epoch}") # DEBUG + print(f"[Trainer.fit][DEBUG] Training loop finished after {self.config.trainer.total_epochs} epochs.") # DEBUG def _create_loss_mask(self, batch, metrics): """Create loss mask for state tokens.""" diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index fb16f8e3..4631fb24 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -32,7 +32,6 @@ import verl.utils.torch_functional as verl_F from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis -import tensordict __all__ = ['DataParallelPPOActor'] @@ -62,17 +61,21 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, entropy: # (bs, response_len) log_probs: # (bs, response_len) """ + print(f"[DP_Actor._forward_micro_batch] Entered. use_remove_padding={self.use_remove_padding}, use_ulysses_sp={self.use_ulysses_sp}") response_length = micro_batch['responses'].size(-1) with torch.autocast(device_type='cuda', dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch_size, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] position_ids = micro_batch['position_ids'] + print(f"[DP_Actor._forward_micro_batch] input_ids device: {input_ids.device}, shape: {input_ids.shape}") if self.use_remove_padding: + print(f"[DP_Actor._forward_micro_batch] Using remove_padding.") input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + print(f"[DP_Actor._forward_micro_batch] input_ids_rmpad shape after unpad: {input_ids_rmpad.shape}") # unpad the position_ids to align the rotary position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), @@ -83,11 +86,13 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, # pad and slice the inputs if sp > 1 if self.use_ulysses_sp: + print(f"[DP_Actor._forward_micro_batch] Using Ulysses SP. 
SP size: {self.ulysses_sequence_parallel_size}") input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ position_ids_rmpad, \ sp_size=self.ulysses_sequence_parallel_size) input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size) + print(f"[DP_Actor._forward_micro_batch] input_ids_rmpad shape after SP slice: {input_ids_rmpad.shape}, pad_size: {pad_size}") input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) @@ -97,6 +102,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, position_ids=position_ids_rmpad, use_cache=False) # prevent model thinks we are generating logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + print(f"[DP_Actor._forward_micro_batch] logits_rmpad device: {logits_rmpad.device}, shape: {logits_rmpad.shape}") logits_rmpad.div_(temperature) @@ -108,6 +114,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, # gather log_prob if sp > 1 if self.use_ulysses_sp: + print(f"[DP_Actor._forward_micro_batch] Gathering outputs for SP.") # gather and unpad for the ulysses sp log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size) entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad, @@ -123,32 +130,41 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, indices=indices, batch=batch_size, seqlen=seqlen) + print(f"[DP_Actor._forward_micro_batch] full_log_probs shape after pad_input: {full_log_probs.shape}") # only return response part: entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) else: # not using rmpad and no ulysses sp + print(f"[DP_Actor._forward_micro_batch] Not using remove_padding.") output = self.actor_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) # prevent model thinks we are generating logits = output.logits + print(f"[DP_Actor._forward_micro_batch] logits device: {logits.device}, shape: {logits.shape}") logits.div_(temperature) logits = logits[:, -response_length - 1:-1] # (bsz, response_length) log_probs = logprobs_from_logits(logits, micro_batch['responses']) entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - + print(f"[DP_Actor._forward_micro_batch] log_probs shape: {log_probs.shape}, entropy shape: {entropy.shape}") + + print(f"[DP_Actor._forward_micro_batch] Exiting.") return entropy, log_probs def _optimizer_step(self): + print(f"[DP_Actor._optimizer_step] Entered.") assert self.config.grad_clip is not None if isinstance(self.actor_module, FSDP): + print(f"[DP_Actor._optimizer_step] Clipping grad norm for FSDP module.") grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) else: + print(f"[DP_Actor._optimizer_step] Clipping grad norm for standard module.") grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) self.actor_optimizer.step() + print(f"[DP_Actor._optimizer_step] Optimizer step done. Grad norm: {grad_norm}. Exiting.") return grad_norm def compute_log_prob(self, data: DataProto) -> torch.Tensor: @@ -169,93 +185,60 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: Returns: torch.Tensor: the log_prob tensor """ + print(f"[DP_Actor.compute_log_prob] Entered. 
Data meta_info: {data.meta_info}") # set to eval + print(f"[DP_Actor.compute_log_prob] Setting actor_module to eval mode.") self.actor_module.eval() + print(f"[DP_Actor.compute_log_prob] actor_module is in eval mode: {not self.actor_module.training}") micro_batch_size = data.meta_info['micro_batch_size'] temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] - - # --- DEBUG: Log device before select --- - original_device = 'Unknown' - if 'input_ids' in data.batch: - original_device = data.batch['input_ids'].device - print(f"[DP_Actor.compute_log_prob] Original data device: {original_device}") - batch = data.select(batch_keys=select_keys).batch - - # --- DEBUG: Log device after select --- - select_device = 'Unknown' - if 'input_ids' in batch: - select_device = batch['input_ids'].device - print(f"[DP_Actor.compute_log_prob] Device after select: {select_device}") - - # Move data to CPU for splitting - print(f"[DP_Actor.compute_log_prob] Moving batch to CPU before split") - batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) - print(f"[DP_Actor.compute_log_prob] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") + print(f"[DP_Actor.compute_log_prob] Selected batch keys. input_ids shape: {batch['input_ids'].shape}, responses shape: {batch['responses'].shape}") if use_dynamic_bsz: # split using dynamic bsz max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size - print(f"[DP_Actor.compute_log_prob] Using dynamic batch size with max_token_len={max_token_len}") - micro_batches, indices = rearrange_micro_batches(batch=batch_cpu, max_token_len=max_token_len) + print(f"[DP_Actor.compute_log_prob] Using dynamic batch size. max_token_len (incl. 
SP): {max_token_len}") + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: - print(f"[DP_Actor.compute_log_prob] Using fixed batch size with micro_batch_size={micro_batch_size}") - micro_batches = batch_cpu.split(micro_batch_size) + print(f"[DP_Actor.compute_log_prob] Using fixed micro_batch_size: {micro_batch_size}") + micro_batches = batch.split(micro_batch_size) log_probs_lst = [] - for mb_idx, micro_batch in enumerate(micro_batches): - # --- DEBUG: Log micro_batch device before moving to CUDA --- - mb_device = 'Unknown' - if 'input_ids' in micro_batch: - mb_device = micro_batch['input_ids'].device - print(f"[DP_Actor.compute_log_prob] Micro-batch {mb_idx} device before potential CUDA move: {mb_device}") - - # Conditionally move to CUDA if available - target_device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu') - needs_move = False - if 'input_ids' in micro_batch: - if micro_batch['input_ids'].device != target_device: - needs_move = True - - if needs_move and torch.cuda.is_available(): - print(f"[DP_Actor.compute_log_prob] Moving micro-batch {mb_idx} to {target_device}") - micro_batch = micro_batch.to(target_device) - # --- DEBUG: Log micro_batch device after moving to CUDA --- - after_device = 'Unknown' - if 'input_ids' in micro_batch: - after_device = micro_batch['input_ids'].device - print(f"[DP_Actor.compute_log_prob] Micro-batch {mb_idx} device after move: {after_device}") - + print(f"[DP_Actor.compute_log_prob] Starting micro-batch loop for {len(micro_batches)} micro-batches.") + for i, micro_batch_data in enumerate(micro_batches): + print(f"[DP_Actor.compute_log_prob] Processing micro-batch {i+1}/{len(micro_batches)}. Device of input_ids: {micro_batch_data['input_ids'].device}") with torch.no_grad(): - _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) - # --- DEBUG: Log log_probs device --- - print(f"[DP_Actor.compute_log_prob] Log probs device for micro-batch {mb_idx}: {log_probs.device}") + _, log_probs = self._forward_micro_batch(micro_batch_data, temperature=temperature) log_probs_lst.append(log_probs) - - print(f"[DP_Actor.compute_log_prob] Concatenating {len(log_probs_lst)} micro-batches") + print(f"[DP_Actor.compute_log_prob] Micro-batch loop finished.") log_probs = torch.concat(log_probs_lst, dim=0) - print(f"[DP_Actor.compute_log_prob] Concatenated log_probs device: {log_probs.device}") if use_dynamic_bsz: + print(f"[DP_Actor.compute_log_prob] Reverting dynamic batch size ordering.") indices = list(itertools.chain.from_iterable(indices)) assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long, device=log_probs.device) + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) log_probs = log_probs[revert_indices] - + + print(f"[DP_Actor.compute_log_prob] Exiting. Final log_probs shape: {log_probs.shape}") return log_probs def update_policy(self, data: DataProto): + print(f"[DP_Actor.update_policy] Entered. 
Data meta_info: {data.meta_info}") # make sure we are in training mode + print(f"[DP_Actor.update_policy] Setting actor_module to train mode.") self.actor_module.train() + print(f"[DP_Actor.update_policy] actor_module is in train mode: {self.actor_module.training}") assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0 self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error + print(f"[DP_Actor.update_policy] Grad accumulation: {self.gradient_accumulation}, Temp: {temperature}") select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages'] if self.config.state_masking: @@ -263,77 +246,55 @@ def update_policy(self, data: DataProto): if self.config.use_kl_loss: select_keys.append('ref_log_prob') batch = data.select(batch_keys=select_keys).batch - - # --- DEBUG and FIX: Log device information and move data to CPU before split --- - print(f"[DP_Actor.update_policy] Moving batch to CPU before split") - batch_device = 'Unknown' - if 'input_ids' in batch: - batch_device = batch['input_ids'].device - print(f"[DP_Actor.update_policy] Device BEFORE move to CPU: {batch_device}") - - # Fix: First create a dictionary with CPU tensors, then create a TensorDict - batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) - print(f"[DP_Actor.update_policy] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") + print(f"[DP_Actor.update_policy] Selected batch keys for training. input_ids shape: {batch['input_ids'].shape}") # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - print(f"[DP_Actor.update_policy] Device for split: cpu") - dataloader = batch_cpu.split(self.config.ppo_mini_batch_size) - print(f"[DP_Actor.update_policy] Dataloader created after split") + dataloader = batch.split(self.config.ppo_mini_batch_size) + print(f"[DP_Actor.update_policy] Created dataloader for {len(dataloader)} mini-batches.") metrics = {} - for batch_idx, data in enumerate(dataloader): - # --- DEBUG: Log mini-batch device --- - mb_device = 'Unknown' - if 'input_ids' in data: - mb_device = data['input_ids'].device - print(f"[DP_Actor.update_policy] Mini-batch {batch_idx} device: {mb_device}") - + for batch_idx, mini_batch_data_container in enumerate(dataloader): + print(f"[DP_Actor.update_policy] Processing mini-batch {batch_idx+1}/{len(dataloader)}.") # split batch into micro_batches - mini_batch = data + # mini_batch = mini_batch_data_container # already a TensorDict from split if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + print(f"[DP_Actor.update_policy] Using dynamic micro-batch for mini-batch {batch_idx+1}. 
max_token_len: {max_token_len}") + micro_batches, _ = rearrange_micro_batches(batch=mini_batch_data_container, max_token_len=max_token_len) else: # split batch into micro_batches - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size) - + fixed_micro_batch_size = self.config.ppo_micro_batch_size + print(f"[DP_Actor.update_policy] Using fixed micro-batch size for mini-batch {batch_idx+1}: {fixed_micro_batch_size}") + micro_batches = mini_batch_data_container.split(fixed_micro_batch_size) + + print(f"[DP_Actor.update_policy] Mini-batch {batch_idx+1} split into {len(micro_batches)} micro-batches.") self.actor_optimizer.zero_grad() - - for micro_batch_idx, data in enumerate(micro_batches): - # --- DEBUG: Log micro-batch device before moving to CUDA --- - before_cuda_device = 'Unknown' - if 'input_ids' in data: - before_cuda_device = data['input_ids'].device - print(f"[DP_Actor.update_policy] Micro-batch {batch_idx}-{micro_batch_idx} device BEFORE .cuda(): {before_cuda_device}") + print(f"[DP_Actor.update_policy] Optimizer zero_grad done for mini-batch {batch_idx+1}.") + + for i, micro_batch_data in enumerate(micro_batches): + print(f"[DP_Actor.update_policy] Forward/Backward for micro-batch {i+1}/{len(micro_batches)} of mini-batch {batch_idx+1}.") + # Ensure data is on CUDA for the forward pass, FSDP handles sharding. + # Verl's default behavior might keep it on CPU if offloading is used, then FSDP moves shards. + # For direct model call, it must be on the device FSDP expects for its root module or where computation occurs. + # If not using FSDP, or if FSDP is parameter-only offload, this explicit .cuda() is important. + micro_batch_data_cuda = micro_batch_data.cuda() + print(f"[DP_Actor.update_policy] Micro-batch {i+1} input_ids device: {micro_batch_data_cuda['input_ids'].device}") - # Conditionally move data to CUDA - if torch.cuda.is_available(): - data = data.cuda() # actor device is cpu when using offload - - # --- DEBUG: Log micro-batch device after moving to CUDA --- - after_cuda_device = 'Unknown' - if 'input_ids' in data: - after_cuda_device = data['input_ids'].device - print(f"[DP_Actor.update_policy] Micro-batch {batch_idx}-{micro_batch_idx} device AFTER .cuda(): {after_cuda_device}") - else: - print(f"[DP_Actor.update_policy] CUDA not available, staying on CPU") - - responses = data['responses'] + responses = micro_batch_data_cuda['responses'] response_length = responses.size(1) - attention_mask = data['attention_mask'] + attention_mask = micro_batch_data_cuda['attention_mask'] response_mask = attention_mask[:, -response_length:] if self.config.state_masking: - response_mask = data['loss_mask'] - old_log_prob = data['old_log_probs'] - advantages = data['advantages'] + response_mask = micro_batch_data_cuda['loss_mask'] + old_log_prob = micro_batch_data_cuda['old_log_probs'] + advantages = micro_batch_data_cuda['advantages'] clip_ratio = self.config.clip_ratio entropy_coeff = self.config.entropy_coeff # all return: (bsz, response_length) - entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature) + entropy, log_prob = self._forward_micro_batch(micro_batch=micro_batch_data_cuda, temperature=temperature) pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob, log_prob=log_prob, @@ -345,9 +306,10 @@ def update_policy(self, data: DataProto): # compute policy loss policy_loss = pg_loss - entropy_loss * entropy_coeff + print(f"[DP_Actor.update_policy] Micro-batch {i+1} losses: pg_loss={pg_loss.item():.4f}, 
entropy_loss={entropy_loss.item():.4f}, policy_loss={policy_loss.item():.4f}") if self.config.use_kl_loss: - ref_log_prob = data['ref_log_prob'] + ref_log_prob = micro_batch_data_cuda['ref_log_prob'] # compute kl loss kld = core_algos.kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, @@ -357,20 +319,26 @@ def update_policy(self, data: DataProto): policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics['actor/kl_loss'] = kl_loss.detach().item() metrics['actor/kl_coef'] = self.config.kl_loss_coef + print(f"[DP_Actor.update_policy] Micro-batch {i+1} KL loss: {kl_loss.item():.4f}, updated policy_loss: {policy_loss.item():.4f}") loss = policy_loss / self.gradient_accumulation + print(f"[DP_Actor.update_policy] Micro-batch {i+1} final loss (for backward): {loss.item():.4f}") loss.backward() + print(f"[DP_Actor.update_policy] Micro-batch {i+1} backward pass done.") - data = { + current_metrics_data = { 'actor/entropy_loss': entropy_loss.detach().item(), 'actor/pg_loss': pg_loss.detach().item(), 'actor/pg_clipfrac': pg_clipfrac.detach().item(), 'actor/ppo_kl': ppo_kl.detach().item(), } - append_to_dict(metrics, data) + append_to_dict(metrics, current_metrics_data) grad_norm = self._optimizer_step() - data = {'actor/grad_norm': grad_norm.detach().item()} - append_to_dict(metrics, data) + optimizer_step_metrics = {'actor/grad_norm': grad_norm.detach().item()} + append_to_dict(metrics, optimizer_step_metrics) + print(f"[DP_Actor.update_policy] Optimizer step done for mini-batch {batch_idx+1}. Grad norm: {grad_norm.item():.4f}") + self.actor_optimizer.zero_grad() + print(f"[DP_Actor.update_policy] Final optimizer zero_grad. Exiting. Metrics: {metrics}") return metrics diff --git a/verl/workers/actor/dp_actor_changed.py b/verl/workers/actor/dp_actor_changed.py new file mode 100644 index 00000000..fb16f8e3 --- /dev/null +++ b/verl/workers/actor/dp_actor_changed.py @@ -0,0 +1,376 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
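For orientation before the file body: dp_actor_changed.py preserves the earlier variant of the actor, whose main difference from the updated dp_actor.py above is that it moves the selected batch to CPU before splitting and only transfers each micro-batch to the GPU right before its forward pass. A minimal sketch of that pattern follows, using only the tensordict calls this file itself relies on (TensorDict, split, to); the tensor names and sizes are illustrative, not the project's data layout.

import torch
import tensordict

# Toy stand-in for the selected PPO batch (the real code selects
# responses/input_ids/attention_mask/position_ids from a DataProto).
batch = tensordict.TensorDict(
    source={
        'input_ids': torch.randint(0, 100, (8, 16)),
        'attention_mask': torch.ones(8, 16, dtype=torch.long),
    },
    batch_size=[8],
)

# 1) Keep the full batch on CPU so splitting never allocates GPU memory.
batch_cpu = batch.to('cpu')

# 2) Split into fixed-size micro-batches (the dynamic-bsz path would use
#    rearrange_micro_batches instead).
micro_batches = batch_cpu.split(2)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
for micro_batch in micro_batches:
    # 3) Move only the current micro-batch to the compute device on demand.
    micro_batch = micro_batch.to(device)
    with torch.no_grad():
        pass  # the forward pass over `micro_batch` goes here

The updated dp_actor.py above drops this explicit CPU round-trip and splits the selected batch wherever it already lives, moving micro-batches with .cuda() only in update_policy.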
+""" +Single Process Actor +""" + +import itertools +from typing import Iterable, Tuple + +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl import DataProto +from verl.trainer.ppo import core_algos +from verl.workers.actor import BasePPOActor +from verl.utils.py_functional import append_to_dict +from verl.utils.torch_functional import logprobs_from_logits, masked_mean +from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad +from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx +import verl.utils.torch_functional as verl_F + +from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +import tensordict + +__all__ = ['DataParallelPPOActor'] + + +class DataParallelPPOActor(BasePPOActor): + + def __init__( + self, + config, + actor_module: nn.Module, + actor_optimizer: torch.optim.Optimizer = None, + ): + """When optimizer is None, it is Reference Policy""" + super().__init__(config) + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + self.use_remove_padding = self.config.get('use_remove_padding', False) + print(f'Actor use_remove_padding={self.use_remove_padding}') + self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size + self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + + self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) + + def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + entropy: # (bs, response_len) + log_probs: # (bs, response_len) + """ + response_length = micro_batch['responses'].size(-1) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + input_ids = micro_batch['input_ids'] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch['attention_mask'] + position_ids = micro_batch['position_ids'] + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), + attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices).transpose(0, 1) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ + position_ids_rmpad, \ + sp_size=self.ulysses_sequence_parallel_size) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None, + self.ulysses_sequence_parallel_size) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.actor_module(input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False) # prevent model thinks we are generating + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size) + entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size) + # pad back to (bsz, seqlen) + full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen) + full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen) + + # only return response part: + entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + output = self.actor_module(input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False) # prevent model thinks we are generating + logits = output.logits + logits.div_(temperature) + logits = logits[:, -response_length - 1:-1] # (bsz, response_length) + log_probs = logprobs_from_logits(logits, micro_batch['responses']) + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + + return entropy, log_probs + + def _optimizer_step(self): + assert self.config.grad_clip is not None + + if isinstance(self.actor_module, FSDP): + grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + self.actor_optimizer.step() + return grad_norm + + def compute_log_prob(self, data: DataProto) -> torch.Tensor: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. 
+ + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + torch.Tensor: the log_prob tensor + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info['micro_batch_size'] + temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error + use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + + # --- DEBUG: Log device before select --- + original_device = 'Unknown' + if 'input_ids' in data.batch: + original_device = data.batch['input_ids'].device + print(f"[DP_Actor.compute_log_prob] Original data device: {original_device}") + + batch = data.select(batch_keys=select_keys).batch + + # --- DEBUG: Log device after select --- + select_device = 'Unknown' + if 'input_ids' in batch: + select_device = batch['input_ids'].device + print(f"[DP_Actor.compute_log_prob] Device after select: {select_device}") + + # Move data to CPU for splitting + print(f"[DP_Actor.compute_log_prob] Moving batch to CPU before split") + batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) + print(f"[DP_Actor.compute_log_prob] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") + + if use_dynamic_bsz: + # split using dynamic bsz + max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + print(f"[DP_Actor.compute_log_prob] Using dynamic batch size with max_token_len={max_token_len}") + micro_batches, indices = rearrange_micro_batches(batch=batch_cpu, max_token_len=max_token_len) + else: + print(f"[DP_Actor.compute_log_prob] Using fixed batch size with micro_batch_size={micro_batch_size}") + micro_batches = batch_cpu.split(micro_batch_size) + + log_probs_lst = [] + for mb_idx, micro_batch in enumerate(micro_batches): + # --- DEBUG: Log micro_batch device before moving to CUDA --- + mb_device = 'Unknown' + if 'input_ids' in micro_batch: + mb_device = micro_batch['input_ids'].device + print(f"[DP_Actor.compute_log_prob] Micro-batch {mb_idx} device before potential CUDA move: {mb_device}") + + # Conditionally move to CUDA if available + target_device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu') + needs_move = False + if 'input_ids' in micro_batch: + if micro_batch['input_ids'].device != target_device: + needs_move = True + + if needs_move and torch.cuda.is_available(): + print(f"[DP_Actor.compute_log_prob] Moving micro-batch {mb_idx} to {target_device}") + micro_batch = micro_batch.to(target_device) + # --- DEBUG: Log micro_batch device after moving to CUDA --- + after_device = 'Unknown' + if 'input_ids' in micro_batch: + after_device = micro_batch['input_ids'].device + print(f"[DP_Actor.compute_log_prob] Micro-batch {mb_idx} device after move: {after_device}") + + with torch.no_grad(): + _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) + # --- DEBUG: Log log_probs device --- + print(f"[DP_Actor.compute_log_prob] Log probs device for micro-batch {mb_idx}: {log_probs.device}") + log_probs_lst.append(log_probs) + + print(f"[DP_Actor.compute_log_prob] Concatenating {len(log_probs_lst)} micro-batches") + log_probs = torch.concat(log_probs_lst, dim=0) + print(f"[DP_Actor.compute_log_prob] Concatenated log_probs device: {log_probs.device}") + + if use_dynamic_bsz: + indices = 
list(itertools.chain.from_iterable(indices)) + assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long, device=log_probs.device) + log_probs = log_probs[revert_indices] + + return log_probs + + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0 + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size + temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error + + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages'] + if self.config.state_masking: + select_keys.append('loss_mask') + if self.config.use_kl_loss: + select_keys.append('ref_log_prob') + batch = data.select(batch_keys=select_keys).batch + + # --- DEBUG and FIX: Log device information and move data to CPU before split --- + print(f"[DP_Actor.update_policy] Moving batch to CPU before split") + batch_device = 'Unknown' + if 'input_ids' in batch: + batch_device = batch['input_ids'].device + print(f"[DP_Actor.update_policy] Device BEFORE move to CPU: {batch_device}") + + # Fix: First create a dictionary with CPU tensors, then create a TensorDict + batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) + print(f"[DP_Actor.update_policy] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 + print(f"[DP_Actor.update_policy] Device for split: cpu") + dataloader = batch_cpu.split(self.config.ppo_mini_batch_size) + print(f"[DP_Actor.update_policy] Dataloader created after split") + + metrics = {} + for batch_idx, data in enumerate(dataloader): + # --- DEBUG: Log mini-batch device --- + mb_device = 'Unknown' + if 'input_ids' in data: + mb_device = data['input_ids'].device + print(f"[DP_Actor.update_policy] Mini-batch {batch_idx} device: {mb_device}") + + # split batch into micro_batches + mini_batch = data + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + # split batch into micro_batches + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size) + + self.actor_optimizer.zero_grad() + + for micro_batch_idx, data in enumerate(micro_batches): + # --- DEBUG: Log micro-batch device before moving to CUDA --- + before_cuda_device = 'Unknown' + if 'input_ids' in data: + before_cuda_device = data['input_ids'].device + print(f"[DP_Actor.update_policy] Micro-batch {batch_idx}-{micro_batch_idx} device BEFORE .cuda(): {before_cuda_device}") + + # Conditionally move data to CUDA + if torch.cuda.is_available(): + data = data.cuda() # actor device is cpu when using offload + + # --- DEBUG: Log micro-batch device after moving to CUDA --- + after_cuda_device = 'Unknown' + if 'input_ids' in data: + after_cuda_device = data['input_ids'].device + print(f"[DP_Actor.update_policy] Micro-batch {batch_idx}-{micro_batch_idx} device AFTER .cuda(): {after_cuda_device}") + else: + print(f"[DP_Actor.update_policy] CUDA not available, staying on CPU") + + responses = data['responses'] + response_length = responses.size(1) + attention_mask = data['attention_mask'] + response_mask = attention_mask[:, -response_length:] + if self.config.state_masking: + response_mask = data['loss_mask'] + old_log_prob = data['old_log_probs'] + advantages = data['advantages'] + + clip_ratio = self.config.clip_ratio + entropy_coeff = self.config.entropy_coeff + + # all return: (bsz, response_length) + entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature) + + pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + eos_mask=response_mask, + cliprange=clip_ratio) + # compute entropy loss from entropy + entropy_loss = verl_F.masked_mean(entropy, response_mask) + + # compute policy loss + policy_loss = pg_loss - entropy_loss * entropy_coeff + + if self.config.use_kl_loss: + ref_log_prob = data['ref_log_prob'] + # compute kl loss + kld = core_algos.kl_penalty(logprob=log_prob, + ref_logprob=ref_log_prob, + kl_penalty=self.config.kl_loss_type) + kl_loss = masked_mean(kld, response_mask) + + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + metrics['actor/kl_loss'] = kl_loss.detach().item() + metrics['actor/kl_coef'] = self.config.kl_loss_coef + + loss = policy_loss / self.gradient_accumulation + loss.backward() + + data = { + 'actor/entropy_loss': entropy_loss.detach().item(), + 'actor/pg_loss': pg_loss.detach().item(), + 'actor/pg_clipfrac': pg_clipfrac.detach().item(), + 'actor/ppo_kl': ppo_kl.detach().item(), + } + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {'actor/grad_norm': grad_norm.detach().item()} + append_to_dict(metrics, data) + 
self.actor_optimizer.zero_grad() + return metrics diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index e6f3c8ec..e68a3162 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -16,11 +16,11 @@ """ import itertools from typing import Iterable +from collections import defaultdict import torch import torch.distributed from torch import nn, optim -import tensordict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -52,17 +52,21 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) def _forward_micro_batch(self, micro_batch): + print(f"[DP_Critic._forward_micro_batch] Entered. use_remove_padding={self.use_remove_padding}, use_ulysses_sp={self.ulysses_sequence_parallel_size > 1}") response_length = micro_batch['responses'].size(-1) with torch.autocast(device_type='cuda', dtype=torch.bfloat16): input_ids = micro_batch['input_ids'] batch, seqlen = input_ids.shape attention_mask = micro_batch['attention_mask'] position_ids = micro_batch['position_ids'] + print(f"[DP_Critic._forward_micro_batch] input_ids device: {input_ids.device}, shape: {input_ids.shape}") if self.use_remove_padding: + print(f"[DP_Critic._forward_micro_batch] Using remove_padding.") input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...) input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + print(f"[DP_Critic._forward_micro_batch] input_ids_rmpad shape after unpad: {input_ids_rmpad.shape}") # unpad the position_ids to align the rotary position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), @@ -70,9 +74,11 @@ def _forward_micro_batch(self, micro_batch): # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: + print(f"[DP_Critic._forward_micro_batch] Using Ulysses SP. 
SP size: {self.ulysses_sequence_parallel_size}") input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ position_ids_rmpad, \ sp_size=self.ulysses_sequence_parallel_size) + print(f"[DP_Critic._forward_micro_batch] input_ids_rmpad shape after SP slice: {input_ids_rmpad.shape}, pad_size: {pad_size}") # only pass input_ids and position_ids to enable flash_attn_varlen output = self.critic_module(input_ids=input_ids_rmpad, @@ -81,9 +87,11 @@ def _forward_micro_batch(self, micro_batch): use_cache=False) # prevent model thinks we are generating values_rmpad = output.logits values_rmpad = values_rmpad.squeeze(0) # (total_nnz) + print(f"[DP_Critic._forward_micro_batch] values_rmpad shape after model: {values_rmpad.shape}") # gather output if sp > 1 if self.ulysses_sequence_parallel_size > 1: + print(f"[DP_Critic._forward_micro_batch] Gathering outputs for SP.") values_rmpad = gather_outpus_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, @@ -91,231 +99,209 @@ def _forward_micro_batch(self, micro_batch): # pad it back values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) + print(f"[DP_Critic._forward_micro_batch] values shape after pad_input: {values.shape}") + # Adjust slicing for critic: we need value for the state BEFORE each token in response values = values[:, -response_length - 1:-1] + print(f"[DP_Critic._forward_micro_batch] values shape after slicing for response: {values.shape}") else: + print(f"[DP_Critic._forward_micro_batch] Not using remove_padding.") output = self.critic_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False) # prevent model thinks we are generating values = output.logits + print(f"[DP_Critic._forward_micro_batch] values device: {values.device}, shape: {values.shape}") + # Adjust slicing for critic: we need value for the state BEFORE each token in response values = values[:, -response_length - 1:-1].squeeze(-1) + print(f"[DP_Critic._forward_micro_batch] values shape after slicing for response: {values.shape}") + + print(f"[DP_Critic._forward_micro_batch] Exiting.") return values def _optimizer_step(self): + print(f"[DP_Critic._optimizer_step] Entered.") assert self.config.grad_clip is not None if isinstance(self.critic_module, FSDP): + print(f"[DP_Critic._optimizer_step] Clipping grad norm for FSDP module.") grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip) else: + print(f"[DP_Critic._optimizer_step] Clipping grad norm for standard module.") grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) self.critic_optimizer.step() + print(f"[DP_Critic._optimizer_step] Optimizer step done. Grad norm: {grad_norm}. Exiting.") return grad_norm - def compute_values(self, data: DataProto) -> torch.Tensor: + def compute_values(self, data: DataProto): + print(f"[DP_Critic.compute_values] Entered. Data meta_info: {data.meta_info}") + # Assuming data.meta_info should contain 'micro_batch_size' and 'use_dynamic_bsz' + # These should be set by the trainer before calling this method. + if 'micro_batch_size' not in data.meta_info or 'use_dynamic_bsz' not in data.meta_info: + print("[DP_Critic.compute_values] WARNING: 'micro_batch_size' or 'use_dynamic_bsz' missing from meta_info! 
This might cause errors.") + # Assigning defaults here might mask the issue, but can prevent immediate crash + micro_batch_size = data.meta_info.get('micro_batch_size', 1) # Default to 1 if missing + use_dynamic_bsz = data.meta_info.get('use_dynamic_bsz', False) + data.meta_info['micro_batch_size'] = micro_batch_size + data.meta_info['use_dynamic_bsz'] = use_dynamic_bsz + else: + micro_batch_size = data.meta_info['micro_batch_size'] + use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + + print(f"[DP_Critic.compute_values] Setting critic_module to eval mode.") self.critic_module.eval() - micro_batch_size = data.meta_info['micro_batch_size'] - select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] - - # --- DEBUG: Record original data device --- - original_device = 'Unknown' - if 'input_ids' in data.batch: - original_device = data.batch['input_ids'].device - print(f"[DP_Critic.compute_values] Start - Original data device: {original_device}") + print(f"[DP_Critic.compute_values] critic_module is in eval mode: {not self.critic_module.training}") + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] batch = data.select(batch_keys=select_keys).batch - - # --- DEBUG: Record device after select --- - select_device = 'Unknown' - if 'input_ids' in batch: - select_device = batch['input_ids'].device - print(f"[DP_Critic.compute_values] Device AFTER select: {select_device}") - - use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + print(f"[DP_Critic.compute_values] Selected batch keys. input_ids shape: {batch['input_ids'].shape}, responses shape: {batch['responses'].shape}") + + # Verl's default behavior might send tensors on CPU if FSDP offload is used. + # The forward pass needs data on the appropriate device. + # Let's check the device before splitting. + print(f"[DP_Critic.compute_values] Device BEFORE split: {batch['input_ids'].device}") + + # If tensors are not on CUDA, move them? FSDP might handle this automatically. + # For now, assume FSDP handles device placement for forward pass. - # --- FIX: Move data to CPU before split --- - print(f"[DP_Critic.compute_values] Moving batch to CPU before split") - - # Fix error: Use TensorDict constructor instead of plain dictionary - batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) - print(f"[DP_Critic.compute_values] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") - if use_dynamic_bsz: - # split using dynamic bsz - max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size - print(f"[DP_Critic.compute_values] Using dynamic batch size with max_token_len={max_token_len}") - micro_batches, indices = rearrange_micro_batches(batch=batch_cpu, max_token_len=max_token_len) + max_token_len = data.meta_info.get('max_token_len', 2048) * self.ulysses_sequence_parallel_size # Default if missing + print(f"[DP_Critic.compute_values] Using dynamic batch size. max_token_len (incl. SP): {max_token_len}") + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) else: - print(f"[DP_Critic.compute_values] Using fixed batch size with micro_batch_size={micro_batch_size}") - micro_batches = batch_cpu.split(micro_batch_size) + # micro_batch_size might be 0 if not set correctly by trainer + if micro_batch_size <= 0: + print(f"[DP_Critic.compute_values] ERROR: micro_batch_size is {micro_batch_size}. Cannot split batch. 
Check trainer config.") + # Raise an error or return dummy data? + micro_batch_size = 1 + print(f"[DP_Critic.compute_values] Using fixed micro_batch_size: {micro_batch_size}") + micro_batches = batch.split(micro_batch_size) values_lst = [] - for mb_idx, micro_batch in enumerate(micro_batches): - # --- DEBUG: Record micro_batch device --- - mb_device = 'Unknown' - if 'input_ids' in micro_batch: - mb_device = micro_batch['input_ids'].device - print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} device BEFORE potential .cuda(): {mb_device}") - - # Conditionally move to CUDA - target_device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu') - needs_move = False - if 'input_ids' in micro_batch: - if micro_batch['input_ids'].device != target_device: - needs_move = True - - if needs_move and torch.cuda.is_available(): - print(f"[DP_Critic.compute_values] Moving micro-batch {mb_idx} to {target_device}") - micro_batch = micro_batch.to(target_device) - elif not torch.cuda.is_available(): - print(f"[DP_Critic.compute_values] WARNING: CUDA not available. Staying on CPU.") - else: - print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} already on target device. Skipping move.") - - # --- DEBUG: Record device after move --- - after_mb_device = 'Unknown' - if 'input_ids' in micro_batch: - after_mb_device = micro_batch['input_ids'].device - print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} device AFTER potential .cuda(): {after_mb_device}") - + print(f"[DP_Critic.compute_values] Starting micro-batch loop for {len(micro_batches)} micro-batches.") + for i, micro_batch_data in enumerate(micro_batches): + print(f"[DP_Critic.compute_values] Processing micro-batch {i+1}/{len(micro_batches)}. Device of input_ids: {micro_batch_data['input_ids'].device}") + # Move to GPU if needed for forward pass? Or assume FSDP handles? + # micro_batch_data = micro_batch_data.cuda() # Tentative with torch.no_grad(): - values = self._forward_micro_batch(micro_batch) - # --- DEBUG: Record values device --- - print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} values device: {values.device}") + values = self._forward_micro_batch(micro_batch_data) values_lst.append(values) - - print(f"[DP_Critic.compute_values] Concatenating {len(values_lst)} micro-batches") + print(f"[DP_Critic.compute_values] Micro-batch loop finished.") values = torch.concat(values_lst, dim=0) - print(f"[DP_Critic.compute_values] Concatenated values device: {values.device}") - - responses = data.batch['responses'] - attention_mask = data.batch['attention_mask'] - response_length = responses.size(1) - - # Ensure values and attention_mask are on the same device - if values.device != attention_mask.device: - print(f"[DP_Critic.compute_values] Moving values from {values.device} to {attention_mask.device}") - values = values.to(attention_mask.device) - - values = values * attention_mask[:, -response_length - 1:-1] + print(f"[DP_Critic.compute_values] Concatenated values shape: {values.shape}") + + # No need to multiply by mask here as _forward_micro_batch slices correctly + # responses = data.batch['responses'] + # attention_mask = data.batch['attention_mask'] + # response_length = responses.size(1) + # # values = values * attention_mask[:, -response_length - 1:-1] # Masking done internally or not needed if sliced? 
if use_dynamic_bsz: + print(f"[DP_Critic.compute_values] Reverting dynamic batch size ordering.") indices = list(itertools.chain.from_iterable(indices)) assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long, device=values.device) + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) values = values[revert_indices] - print(f"[DP_Critic.compute_values] Final values shape: {values.shape}, device: {values.device}") - return values + print(f"[DP_Critic.compute_values] Exiting. Final values shape: {values.shape}") + # The function signature in BasePPOCritic implies it should return a DataProto + # containing the values, not just the tensor. + output = DataProto.from_dict(tensors={'values': values}) + return output def update_critic(self, data: DataProto): + print(f"[DP_Critic.update_critic] Entered. Data meta_info: {data.meta_info}") # make sure we are in training mode + print(f"[DP_Critic.update_critic] Setting critic_module to train mode.") self.critic_module.train() + print(f"[DP_Critic.update_critic] critic_module is in train mode: {self.critic_module.training}") metrics = {} - # --- DEBUG: Log initial input data device --- - initial_device = 'Unknown' - if 'input_ids' in data.batch: - initial_device = data.batch['input_ids'].device - print(f"[DP_Critic.update_critic] Start - Input data device: {initial_device}") - # --- END DEBUG --- - select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] - # --- DEBUG: Log device before select --- - print(f"[DP_Critic.update_critic] Device BEFORE select: {initial_device}") batch = data.select(batch_keys=select_keys).batch - # --- DEBUG: Log device after select --- - select_device = 'Unknown' - if 'input_ids' in batch: - select_device = batch['input_ids'].device - print(f"[DP_Critic.update_critic] Device AFTER select: {select_device}") - # --- END DEBUG --- + print(f"[DP_Critic.update_critic] Selected batch keys for training. input_ids shape: {batch['input_ids'].shape}") - # --- Key fix: Move data to CPU before split --- - print(f"[DP_Critic.update_critic] Moving batch to CPU before split") - # Fix error: Use TensorDict constructor instead of plain dictionary - batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) - print(f"[DP_Critic.update_critic] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") - - # Split to make minibatch iterator for updating the actor - # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - # --- DEBUG: Log device before split --- - split_device = 'cpu' # Should be CPU at this point - print(f"[DP_Critic.update_critic] Device BEFORE split: {split_device}") - dataloader = batch_cpu.split(self.config.ppo_mini_batch_size) - # --- DEBUG: Log device after split (dataloader is iterator) --- - print(f"[DP_Critic.update_critic] Dataloader created after split") - # --- END DEBUG --- - - for batch_idx, data in enumerate(dataloader): - # --- DEBUG: Log mini_batch device before micro-batch split --- - mb_device = 'Unknown' - if 'input_ids' in data: - mb_device = data['input_ids'].device - print(f"[DP_Critic.update_critic] Mini-batch {batch_idx} device: {mb_device}") - # --- END DEBUG --- + current_actual_batch_size = batch['input_ids'].shape[0] + configured_critic_ppo_mini_batch_size = self.config.ppo_mini_batch_size + + dataloader = [] # Default to an empty dataloader + + if current_actual_batch_size == 0: + print(f"[DP_Critic.update_critic] Current batch size is 0. Skipping PPO updates for this batch.") + else: + # Determine the effective mini-batch size for splitting. + # It should not be larger than the current actual batch size. + # It also shouldn't be less than 1 if there's data. + effective_mini_batch_size_for_split = min(current_actual_batch_size, configured_critic_ppo_mini_batch_size) + + if effective_mini_batch_size_for_split < 1: # Should ideally not happen if current_actual_batch_size > 0 + print(f"[DP_Critic.update_critic] Warning: effective_mini_batch_size_for_split calculated as {effective_mini_batch_size_for_split} from current_actual_batch_size={current_actual_batch_size} and configured_critic_ppo_mini_batch_size={configured_critic_ppo_mini_batch_size}. Setting to 1.") + effective_mini_batch_size_for_split = 1 + + + if effective_mini_batch_size_for_split < configured_critic_ppo_mini_batch_size: + print(f"[DP_Critic.update_critic] Adjusting PPO mini-batch size for critic update from configured {configured_critic_ppo_mini_batch_size} to actual {effective_mini_batch_size_for_split} due to small input batch size ({current_actual_batch_size}).") + try: + dataloader = batch.split(effective_mini_batch_size_for_split) + except Exception as e: + print(f"[DP_Critic.update_critic] Error during batch.split with effective_mini_batch_size={effective_mini_batch_size_for_split}: {e}") + print(f"[DP_Critic.update_critic] Batch keys: {batch.keys()}, input_ids shape: {batch['input_ids'].shape if 'input_ids' in batch else 'N/A'}") + # Keep dataloader as empty list to skip epochs + + # Try to log the number of mini-batches. + # Note: If dataloader is an iterator, len() might consume it or not be supported. + # The original code used len(), implying it might be a list or has __len__. + try: + num_minibatches = len(dataloader) + print(f"[DP_Critic.update_critic] Created dataloader for {num_minibatches} mini-batches.") + except TypeError: + # This happens if dataloader is an iterator without __len__ + # To get the length, one would need to convert to list, consuming it. + # For now, we'll just note it's an iterator. 
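# A minimal sketch of the mini-batch clamping above, with a hypothetical helper and
# plain torch tensors standing in for the TensorDict batch: the configured
# ppo_mini_batch_size is reduced to the actual batch size so that small (or empty)
# rollout batches split without raising.
import torch

def split_minibatches(batch: torch.Tensor, configured_mini_batch_size: int):
    actual = batch.size(0)
    if actual == 0:
        return []                                    # nothing to update on
    effective = max(1, min(actual, configured_mini_batch_size))
    return list(batch.split(effective, dim=0))       # last chunk may be smaller

fake_batch = torch.randn(3, 8)                       # e.g. only 3 rollouts survived filtering
chunks = split_minibatches(fake_batch, configured_mini_batch_size=256)
assert len(chunks) == 1 and chunks[0].size(0) == 3   # one small mini-batch, no crash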
+ print(f"[DP_Critic.update_critic] Created dataloader (iterator type, length not directly logged to avoid consumption).") + + + metrics_to_avg = defaultdict(list) + + for batch_idx, mini_batch_data_container in enumerate(dataloader): + print(f"[DP_Critic.update_critic] Processing mini-batch {batch_idx+1}/{len(dataloader)}.") # split batch into micro_batches - mini_batch = data - if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + # mini_batch = mini_batch_data_container + # Get dynamic batching config from self.config (critic's own config), not data.meta_info for update loop + use_dynamic_bsz_update = self.config.get('use_dynamic_bsz', False) + if use_dynamic_bsz_update: + # Use critic's configured max token length + max_token_len = self.config.get('ppo_max_token_len_per_gpu', 2048) * self.ulysses_sequence_parallel_size + print(f"[DP_Critic.update_critic] Using dynamic micro-batch for mini-batch {batch_idx+1}. max_token_len: {max_token_len}") + micro_batches, _ = rearrange_micro_batches(batch=mini_batch_data_container, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size) - + fixed_micro_batch_size_update = self.config.ppo_micro_batch_size + print(f"[DP_Critic.update_critic] Using fixed micro-batch size for mini-batch {batch_idx+1}: {fixed_micro_batch_size_update}") + micro_batches = mini_batch_data_container.split(fixed_micro_batch_size_update) + + print(f"[DP_Critic.update_critic] Mini-batch {batch_idx+1} split into {len(micro_batches)} micro-batches.") self.critic_optimizer.zero_grad() + print(f"[DP_Critic.update_critic] Optimizer zero_grad done for mini-batch {batch_idx+1}.") - for micro_batch_idx, data in enumerate(micro_batches): - # --- DEBUG: Log device before potential .cuda() --- - before_cuda_device = 'Unknown' - target_device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu') - needs_move = False - if 'input_ids' in data: - before_cuda_device = data['input_ids'].device - if before_cuda_device != target_device: - needs_move = True - print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} device BEFORE move check: {before_cuda_device}") - # --- END DEBUG --- - - # Conditional .cuda() call - if needs_move and torch.cuda.is_available(): - print(f"[DP_Critic.update_critic] Moving micro-batch {batch_idx}-{micro_batch_idx} from {before_cuda_device} to {target_device}") - data = data.to(target_device) - elif not torch.cuda.is_available(): - print(f"[DP_Critic.update_critic] WARNING: CUDA not available, cannot move micro-batch {batch_idx}-{micro_batch_idx}") - else: - print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} already on target device {target_device}. Skipping move.") - - # --- DEBUG: Log device after potential .cuda() --- - after_cuda_device = 'Unknown' - if 'input_ids' in data: - after_cuda_device = data['input_ids'].device - print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} device AFTER move check: {after_cuda_device}") - # --- END DEBUG --- + for i, micro_batch_data in enumerate(micro_batches): + print(f"[DP_Critic.update_critic] Forward/Backward for micro-batch {i+1}/{len(micro_batches)} of mini-batch {batch_idx+1}.") + # Assuming FSDP handles device placement, but check device. 
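# A minimal sketch, with a plain tensor standing in for the TensorDict micro-batch,
# of the conditional device move used around the forward passes in this patch:
# move to the current CUDA device only when CUDA exists, otherwise stay on CPU
# (e.g. when parameters are offloaded or no GPU is present).
import torch

def to_compute_device(mb: torch.Tensor) -> torch.Tensor:
    if torch.cuda.is_available():
        target = torch.device(f'cuda:{torch.cuda.current_device()}')
        if mb.device != target:
            return mb.to(target)                     # move once, right before the forward pass
    return mb                                        # CPU fallback

mb = torch.zeros(2, 4)
mb = to_compute_device(mb)                           # no-op on CPU-only machines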
+ micro_batch_data_cuda = micro_batch_data.cuda() + print(f"[DP_Critic.update_critic] Micro-batch {i+1} input_ids device: {micro_batch_data_cuda['input_ids'].device}") - input_ids = data['input_ids'] - responses = data['responses'] - attention_mask = data['attention_mask'] - position_ids = data['position_ids'] - values = data['values'] - returns = data['returns'] + # input_ids = micro_batch_data_cuda['input_ids'] # Not directly needed for loss + responses = micro_batch_data_cuda['responses'] + attention_mask = micro_batch_data_cuda['attention_mask'] + # position_ids = micro_batch_data_cuda['position_ids'] # Needed by _forward_micro_batch + values = micro_batch_data_cuda['values'] + returns = micro_batch_data_cuda['returns'] response_length = responses.size(1) + # Mask for loss calculation corresponds to the response part where values/returns are defined eos_mask = attention_mask[:, -response_length - 1:-1] - # --- DEBUG: Log device before forward pass --- - forward_input_device = 'Unknown' - if 'input_ids' in data: - forward_input_device = data['input_ids'].device - print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} device BEFORE forward pass: {forward_input_device}") - # --- END DEBUG --- - - vpreds = self._forward_micro_batch(data) - - # --- DEBUG: Log vpreds device --- - print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} vpreds device: {vpreds.device}") - # --- END DEBUG --- + vpreds = self._forward_micro_batch(micro_batch_data_cuda) + print(f"[DP_Critic.update_critic] Micro-batch {i+1} vpreds shape: {vpreds.shape}") # assert not torch.any(torch.isnan(vpreds)).item() @@ -324,19 +310,29 @@ def update_critic(self, data: DataProto): returns=returns, eos_mask=eos_mask, cliprange_value=self.config.cliprange_value) - loss = vf_loss / self.gradient_accumulation + + # Normalize loss by gradient accumulation steps + # Determine accumulation steps from config + gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size + loss = vf_loss / gradient_accumulation + print(f"[DP_Critic.update_critic] Micro-batch {i+1} losses: vf_loss={vf_loss.item():.4f}, clipfrac={vf_clipfrac.item():.4f}, final_loss={loss.item():.4f}") loss.backward() + print(f"[DP_Critic.update_critic] Micro-batch {i+1} backward pass done.") - data = { + loss_data_metrics = { 'critic/vf_loss': vf_loss.detach().item(), 'critic/vf_clipfrac': vf_clipfrac.detach().item(), 'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(), } - - append_to_dict(metrics, data) + append_to_dict(metrics_to_avg, loss_data_metrics) grad_norm = self._optimizer_step() - data = {'critic/grad_norm': grad_norm.detach().item()} - append_to_dict(metrics, data) + optimizer_step_metrics = {'critic/grad_norm': grad_norm.detach().item()} + append_to_dict(metrics, optimizer_step_metrics) + print(f"[DP_Critic.update_critic] Optimizer step done for mini-batch {batch_idx+1}. Grad norm: {grad_norm.item():.4f}") + self.critic_optimizer.zero_grad() - return metrics + print(f"[DP_Critic.update_critic] Final optimizer zero_grad. Exiting. Metrics: {metrics}") + + # BasePPOCritic expects a DataProto containing metrics + return DataProto(meta_info={'metrics': metrics}) # Wrap metrics in DataProto diff --git a/verl/workers/critic/dp_critic_changed.py b/verl/workers/critic/dp_critic_changed.py new file mode 100644 index 00000000..e6f3c8ec --- /dev/null +++ b/verl/workers/critic/dp_critic_changed.py @@ -0,0 +1,342 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement a multiprocess PPOCritic +""" +import itertools +from typing import Iterable + +import torch +import torch.distributed +from torch import nn, optim +import tensordict + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl import DataProto +from verl.trainer.ppo import core_algos +from verl.workers.critic import BasePPOCritic +from verl.utils.py_functional import append_to_dict +from verl.utils.torch_functional import masked_mean +from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad +from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx + +from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis + +__all__ = ['DataParallelPPOCritic'] + + +class DataParallelPPOCritic(BasePPOCritic): + + def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer): + super().__init__(config=config) + self.critic_module = critic_module + self.critic_optimizer = critic_optimizer + self.use_remove_padding = self.config.model.get('use_remove_padding', False) + print(f'Critic use_remove_padding={self.use_remove_padding}') + + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0 + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size + + self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) + + def _forward_micro_batch(self, micro_batch): + response_length = micro_batch['responses'].size(-1) + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + input_ids = micro_batch['input_ids'] + batch, seqlen = input_ids.shape + attention_mask = micro_batch['attention_mask'] + position_ids = micro_batch['position_ids'] + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), + attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices).transpose(0, 1) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \ + position_ids_rmpad, \ + sp_size=self.ulysses_sequence_parallel_size) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.critic_module(input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False) # prevent model thinks we are generating + values_rmpad = output.logits + values_rmpad = values_rmpad.squeeze(0) # (total_nnz) + + # gather output if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + values_rmpad = gather_outpus_and_unpad(values_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size) + + # pad it back + values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) + values = values[:, -response_length - 1:-1] + else: + output = self.critic_module(input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False) # prevent model thinks we are generating + values = output.logits + values = values[:, -response_length - 1:-1].squeeze(-1) + return values + + def _optimizer_step(self): + assert self.config.grad_clip is not None + + if isinstance(self.critic_module, FSDP): + grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) + self.critic_optimizer.step() + return grad_norm + + def compute_values(self, data: DataProto) -> torch.Tensor: + self.critic_module.eval() + micro_batch_size = data.meta_info['micro_batch_size'] + select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids'] + + # --- DEBUG: Record original data device --- + original_device = 'Unknown' + if 'input_ids' in data.batch: + original_device = data.batch['input_ids'].device + print(f"[DP_Critic.compute_values] Start - Original data device: {original_device}") + + batch = data.select(batch_keys=select_keys).batch + + # --- DEBUG: Record device after select --- + select_device = 'Unknown' + if 'input_ids' in batch: + select_device = batch['input_ids'].device + print(f"[DP_Critic.compute_values] Device AFTER select: {select_device}") + + use_dynamic_bsz = data.meta_info['use_dynamic_bsz'] + + # --- FIX: Move data to CPU before split --- + print(f"[DP_Critic.compute_values] Moving batch to CPU before split") + + # Fix error: Use TensorDict constructor instead of plain dictionary + batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) + print(f"[DP_Critic.compute_values] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") + + if use_dynamic_bsz: + # split using dynamic bsz + max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size + print(f"[DP_Critic.compute_values] Using dynamic batch size with max_token_len={max_token_len}") + micro_batches, indices = rearrange_micro_batches(batch=batch_cpu, max_token_len=max_token_len) + else: + print(f"[DP_Critic.compute_values] Using fixed batch size with micro_batch_size={micro_batch_size}") + micro_batches = batch_cpu.split(micro_batch_size) + + values_lst = [] + for mb_idx, micro_batch in enumerate(micro_batches): + # --- DEBUG: Record micro_batch device --- + mb_device = 'Unknown' + if 
'input_ids' in micro_batch: + mb_device = micro_batch['input_ids'].device + print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} device BEFORE potential .cuda(): {mb_device}") + + # Conditionally move to CUDA + target_device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu') + needs_move = False + if 'input_ids' in micro_batch: + if micro_batch['input_ids'].device != target_device: + needs_move = True + + if needs_move and torch.cuda.is_available(): + print(f"[DP_Critic.compute_values] Moving micro-batch {mb_idx} to {target_device}") + micro_batch = micro_batch.to(target_device) + elif not torch.cuda.is_available(): + print(f"[DP_Critic.compute_values] WARNING: CUDA not available. Staying on CPU.") + else: + print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} already on target device. Skipping move.") + + # --- DEBUG: Record device after move --- + after_mb_device = 'Unknown' + if 'input_ids' in micro_batch: + after_mb_device = micro_batch['input_ids'].device + print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} device AFTER potential .cuda(): {after_mb_device}") + + with torch.no_grad(): + values = self._forward_micro_batch(micro_batch) + # --- DEBUG: Record values device --- + print(f"[DP_Critic.compute_values] Micro-batch {mb_idx} values device: {values.device}") + values_lst.append(values) + + print(f"[DP_Critic.compute_values] Concatenating {len(values_lst)} micro-batches") + values = torch.concat(values_lst, dim=0) + print(f"[DP_Critic.compute_values] Concatenated values device: {values.device}") + + responses = data.batch['responses'] + attention_mask = data.batch['attention_mask'] + response_length = responses.size(1) + + # Ensure values and attention_mask are on the same device + if values.device != attention_mask.device: + print(f"[DP_Critic.compute_values] Moving values from {values.device} to {attention_mask.device}") + values = values.to(attention_mask.device) + + values = values * attention_mask[:, -response_length - 1:-1] + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == values.size(0), f"{len(indices)} vs. 
{values.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long, device=values.device) + values = values[revert_indices] + + print(f"[DP_Critic.compute_values] Final values shape: {values.shape}, device: {values.device}") + return values + + def update_critic(self, data: DataProto): + # make sure we are in training mode + self.critic_module.train() + metrics = {} + + # --- DEBUG: Log initial input data device --- + initial_device = 'Unknown' + if 'input_ids' in data.batch: + initial_device = data.batch['input_ids'].device + print(f"[DP_Critic.update_critic] Start - Input data device: {initial_device}") + # --- END DEBUG --- + + select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'values', 'returns'] + # --- DEBUG: Log device before select --- + print(f"[DP_Critic.update_critic] Device BEFORE select: {initial_device}") + batch = data.select(batch_keys=select_keys).batch + # --- DEBUG: Log device after select --- + select_device = 'Unknown' + if 'input_ids' in batch: + select_device = batch['input_ids'].device + print(f"[DP_Critic.update_critic] Device AFTER select: {select_device}") + # --- END DEBUG --- + + # --- Key fix: Move data to CPU before split --- + print(f"[DP_Critic.update_critic] Moving batch to CPU before split") + # Fix error: Use TensorDict constructor instead of plain dictionary + batch_cpu_dict = {k: v.to('cpu') if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch_cpu = tensordict.TensorDict(source=batch_cpu_dict, batch_size=batch.batch_size) + print(f"[DP_Critic.update_critic] Created TensorDict on CPU with batch_size={batch_cpu.batch_size}") + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + # --- DEBUG: Log device before split --- + split_device = 'cpu' # Should be CPU at this point + print(f"[DP_Critic.update_critic] Device BEFORE split: {split_device}") + dataloader = batch_cpu.split(self.config.ppo_mini_batch_size) + # --- DEBUG: Log device after split (dataloader is iterator) --- + print(f"[DP_Critic.update_critic] Dataloader created after split") + # --- END DEBUG --- + + for batch_idx, data in enumerate(dataloader): + # --- DEBUG: Log mini_batch device before micro-batch split --- + mb_device = 'Unknown' + if 'input_ids' in data: + mb_device = data['input_ids'].device + print(f"[DP_Critic.update_critic] Mini-batch {batch_idx} device: {mb_device}") + # --- END DEBUG --- + + # split batch into micro_batches + mini_batch = data + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size) + + self.critic_optimizer.zero_grad() + + for micro_batch_idx, data in enumerate(micro_batches): + # --- DEBUG: Log device before potential .cuda() --- + before_cuda_device = 'Unknown' + target_device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu') + needs_move = False + if 'input_ids' in data: + before_cuda_device = data['input_ids'].device + if before_cuda_device != target_device: + needs_move = True + print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} device BEFORE move check: {before_cuda_device}") + # --- END DEBUG --- + + # Conditional .cuda() call + if needs_move and torch.cuda.is_available(): + 
print(f"[DP_Critic.update_critic] Moving micro-batch {batch_idx}-{micro_batch_idx} from {before_cuda_device} to {target_device}") + data = data.to(target_device) + elif not torch.cuda.is_available(): + print(f"[DP_Critic.update_critic] WARNING: CUDA not available, cannot move micro-batch {batch_idx}-{micro_batch_idx}") + else: + print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} already on target device {target_device}. Skipping move.") + + # --- DEBUG: Log device after potential .cuda() --- + after_cuda_device = 'Unknown' + if 'input_ids' in data: + after_cuda_device = data['input_ids'].device + print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} device AFTER move check: {after_cuda_device}") + # --- END DEBUG --- + + input_ids = data['input_ids'] + responses = data['responses'] + attention_mask = data['attention_mask'] + position_ids = data['position_ids'] + values = data['values'] + returns = data['returns'] + response_length = responses.size(1) + + eos_mask = attention_mask[:, -response_length - 1:-1] + + # --- DEBUG: Log device before forward pass --- + forward_input_device = 'Unknown' + if 'input_ids' in data: + forward_input_device = data['input_ids'].device + print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} device BEFORE forward pass: {forward_input_device}") + # --- END DEBUG --- + + vpreds = self._forward_micro_batch(data) + + # --- DEBUG: Log vpreds device --- + print(f"[DP_Critic.update_critic] Micro-batch {batch_idx}-{micro_batch_idx} vpreds device: {vpreds.device}") + # --- END DEBUG --- + + # assert not torch.any(torch.isnan(vpreds)).item() + + vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, + values=values, + returns=returns, + eos_mask=eos_mask, + cliprange_value=self.config.cliprange_value) + loss = vf_loss / self.gradient_accumulation + loss.backward() + + data = { + 'critic/vf_loss': vf_loss.detach().item(), + 'critic/vf_clipfrac': vf_clipfrac.detach().item(), + 'critic/vpred_mean': masked_mean(vpreds, eos_mask).detach().item(), + } + + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {'critic/grad_norm': grad_norm.detach().item()} + append_to_dict(metrics, data) + self.critic_optimizer.zero_grad() + return metrics diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 4a76ecbe..c6e0d176 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -18,6 +18,7 @@ import logging import os import warnings +import json import torch import torch.distributed @@ -739,6 +740,13 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_values(self, data: DataProto): + # === Add Debug Log HERE === + print(f"[FSDP_CriticWorker.compute_values @ Rank {self.rank}] Received data. 
Meta Info: {data.meta_info}") + if 'micro_batch_size' in data.meta_info: + print(f"[FSDP_CriticWorker.compute_values @ Rank {self.rank}] micro_batch_size in received meta_info: {data.meta_info['micro_batch_size']}") + else: + print(f"[FSDP_CriticWorker.compute_values @ Rank {self.rank}] WARNING: 'micro_batch_size' NOT FOUND in received meta_info!") + # === End Debug Log === data = data.to('cuda') if self._is_offload_param: @@ -746,17 +754,19 @@ def compute_values(self, data: DataProto): device_id=torch.cuda.current_device(), load_grad=self._is_offload_grad) micro_batch_size = self.config.forward_micro_batch_size + if micro_batch_size == 0: + micro_batch_size = 1 data.meta_info['micro_batch_size'] = micro_batch_size data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz # perform forward computation with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) - values = self.critic.compute_values(data=data) - output = DataProto.from_dict(tensors={'values': values}) + output = self.critic.compute_values(data=data) + # No need to recreate a DataProto output = self.ulysses_sharding_manager.postprocess_data(data=output) - # output = output.to('cpu') + output = output.to('cpu') if self._is_offload_param: offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad) torch.cuda.empty_cache() @@ -764,6 +774,9 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): + # === Add Debug Log HERE === + print(f"[FSDP_CriticWorker.update_critic @ Rank {self.rank}] Received data. Meta Info: {data.meta_info}") + # === End Debug Log === data = data.to('cuda') if self._is_offload_param: load_fsdp_param_and_grad(module=self.critic_module, @@ -775,20 +788,36 @@ def update_critic(self, data: DataProto): # perform forward computation with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) + + actual_metrics_dict = {} # Initialize an empty dictionary with Timer(name='update_critic', logger=None) as timer: - metrics = self.critic.update_critic(data=data) + # result_proto is the DataProto object from self.critic.update_critic + result_proto = self.critic.update_critic(data=data) delta_time = timer.last + # Extract the metrics dictionary from the result_proto + if isinstance(result_proto, DataProto) and hasattr(result_proto, 'meta_info') and 'metrics' in result_proto.meta_info: + if isinstance(result_proto.meta_info['metrics'], dict): + actual_metrics_dict = result_proto.meta_info['metrics'] + else: + print(f"[FSDP CriticWorker.update_critic] Warning: result_proto.meta_info['metrics'] is not a dict, it's a {type(result_proto.meta_info['metrics'])}. Initializing empty metrics dict.") + else: + print(f"[FSDP CriticWorker.update_critic] Warning: Could not extract metrics dict from result_proto. Initializing empty metrics dict. 
Result_proto type: {type(result_proto)}") + global_num_tokens = data.meta_info['global_token_num'] estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + + # Now operate on the actual_metrics_dict + actual_metrics_dict['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size self.critic_lr_scheduler.step() lr = self.critic_lr_scheduler.get_last_lr()[0] - metrics['critic/lr'] = lr + # Operate on the extracted dict + actual_metrics_dict['critic/lr'] = lr - output = DataProto(batch=None, meta_info={'metrics': metrics}) + # Create the output DataProto using the modified actual_metrics_dict + output = DataProto(batch=None, meta_info={'metrics': actual_metrics_dict}) output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: