diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 7fd10736f..c5df455c4 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -112,11 +112,25 @@ def __init__(
             self.model.parameters(), lr=self.learning_rate
         )
 
+        # Initialize metrics storage
+        self.log_dict = {}
+
         self.logger.info(f"Model initialized on {self.device}")
 
+    @endpoint
+    async def get_metrics(self):
+        """Return metrics dict for external logger to log."""
+        return self.log_dict.copy()
+
     @endpoint
     async def train_step(self, batch: list[Episode]):
         total_loss = 0.0
+        total_kl_loss = 0.0
+        total_pg_loss = 0.0
+        total_ratio_mean = 0.0
+        total_ratio_std = 0.0
+        total_response_len = 0.0
+        total_advantages = 0.0
         num_groups_processed = 0
 
         for episode in batch:
@@ -170,6 +184,21 @@ async def train_step(self, batch: list[Episode]):
             # Total GRPO loss
             loss = pg_loss + kl_penalty
             total_loss += loss.item()
+            total_kl_loss += kl_penalty.item()
+            total_pg_loss += pg_loss.item()
+            total_ratio_mean += ratio.detach().float().cpu().numpy().mean()
+            total_ratio_std += ratio.detach().float().cpu().numpy().std()
+
+            # Calculate average response length for this episode
+            episode_response_len = sum(
+                len(response) for response in response_texts
+            ) / len(response_texts)
+            total_response_len += episode_response_len
+
+            # Calculate average advantages for this episode
+            episode_advantages = advantages_tensor.detach().float().cpu().numpy().mean()
+            total_advantages += episode_advantages
+
             num_groups_processed += len(groups)
 
         self.optimizer.zero_grad()
@@ -180,7 +209,30 @@ async def train_step(self, batch: list[Episode]):
 
         self.optimizer.step()
 
-        avg_loss = total_loss / len(batch) if batch else 0.0
+        # Compute averaged metrics across the batch
+        if batch:
+            avg_loss = total_loss / len(batch)
+            avg_kl_loss = total_kl_loss / len(batch)
+            avg_pg_loss = total_pg_loss / len(batch)
+            avg_ratio_mean = total_ratio_mean / len(batch)
+            avg_ratio_std = total_ratio_std / len(batch)
+            avg_response_len = total_response_len / len(batch)
+            avg_advantages = total_advantages / len(batch)
+        else:
+            avg_loss = avg_kl_loss = avg_pg_loss = avg_ratio_mean = avg_ratio_std = (
+                avg_response_len
+            ) = avg_advantages = 0.0
+
+        # Store averaged metrics for external logging
+        self.log_dict = {
+            "loss/total": avg_loss,
+            "loss/kl": avg_kl_loss,
+            "loss/policy": avg_pg_loss,
+            "metrics/ratio_mean": avg_ratio_mean,
+            "metrics/ratio_std": avg_ratio_std,
+            "metrics/response_len": avg_response_len,
+            "metrics/advantages": avg_advantages,
+        }
 
         return {"loss": avg_loss, "groups_processed": num_groups_processed}
 
@@ -460,7 +512,7 @@ async def continuous_rollouts():
                 print(
                     f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
                 )
-                logger.log("reward/rollout", avg_reward, rollout_count)
+                logger.log("metrics/reward_per_ten_rollout", avg_reward, rollout_count)
 
     async def continuous_training():
         training_step = 0
@@ -476,7 +528,11 @@ async def continuous_training():
             if training_result:
                 loss_value = training_result.get("loss", 0.0)
                 print(f"Latest loss: {loss_value}")
-                logger.log("loss/training_step", loss_value, training_step)
+
+                # Get and log detailed metrics
+                metrics = await trainer.get_metrics.choose()
+                logger.log_dict(metrics, training_step)
+
                 # await trainer.update_weights(policy)
 
     print("Starting GRPO training loops...")
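
Note: the patch above follows a simple cache-and-poll pattern: train_step averages per-episode metrics over the batch and caches them in self.log_dict, get_metrics hands back a copy, and the training loop forwards that dict to logger.log_dict. The minimal sketch below reproduces only that flow in plain asyncio, without the forge @endpoint decorator or the .choose() call; MiniTrainer, PrintLogger, and the random stand-in loss terms are hypothetical illustrations, not part of the patch.

import asyncio
import random


class MiniTrainer:
    """Caches averaged per-batch metrics so an external loop can poll them."""

    def __init__(self):
        self.log_dict = {}

    async def get_metrics(self):
        # Return a copy so the caller cannot mutate internal state.
        return self.log_dict.copy()

    async def train_step(self, batch):
        total_loss = total_kl = 0.0
        for _episode in batch:
            # Stand-ins for the real per-episode policy-gradient and KL terms.
            pg_loss = random.random()
            kl_penalty = 0.1 * random.random()
            total_loss += pg_loss + kl_penalty
            total_kl += kl_penalty
        n = len(batch) or 1
        # Store averaged metrics for external logging.
        self.log_dict = {
            "loss/total": total_loss / n,
            "loss/kl": total_kl / n,
        }
        return {"loss": self.log_dict["loss/total"]}


class PrintLogger:
    """Hypothetical stand-in for a logger exposing log_dict(metrics, step)."""

    def log_dict(self, metrics, step):
        for key, value in metrics.items():
            print(f"step={step} {key}={value:.4f}")


async def main():
    trainer, logger = MiniTrainer(), PrintLogger()
    for step in range(3):
        await trainer.train_step(batch=[object()] * 4)
        metrics = await trainer.get_metrics()  # poll cached metrics
        logger.log_dict(metrics, step)


asyncio.run(main())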