62 changes: 59 additions & 3 deletions apps/grpo/main.py
@@ -112,11 +112,25 @@ def __init__(
             self.model.parameters(), lr=self.learning_rate
         )
 
+        # Initialize metrics storage
+        self.log_dict = {}
+
         self.logger.info(f"Model initialized on {self.device}")
 
     @endpoint
+    async def get_metrics(self):
+        """Return metrics dict for external logger to log."""
+        return self.log_dict.copy()
+
+    @endpoint
     async def train_step(self, batch: list[Episode]):
         total_loss = 0.0
+        total_kl_loss = 0.0
+        total_pg_loss = 0.0
+        total_ratio_mean = 0.0
+        total_ratio_std = 0.0
+        total_response_len = 0.0
+        total_advantages = 0.0
         num_groups_processed = 0
 
         for episode in batch:
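Note: the new get_metrics endpoint hands the latest train_step metrics to whichever caller polls it, and returning self.log_dict.copy() keeps the caller's snapshot stable even if a later train_step replaces the dict. A minimal sketch of the same hand-off pattern outside the actor framework (the class and names below are illustrative, not from this PR):

    class MetricsSource:
        """Producer rebuilds metrics each step; consumers poll a snapshot."""

        def __init__(self):
            self.log_dict = {}

        def train_step(self, step: int):
            # Producer side: assign a fresh dict rather than mutating in place.
            self.log_dict = {"loss/total": 1.0 / (step + 1), "step": step}

        def get_metrics(self) -> dict:
            # Consumer side: shallow copy, so later steps cannot mutate
            # what the caller is holding.
            return self.log_dict.copy()

    source = MetricsSource()
    source.train_step(0)
    snapshot = source.get_metrics()
    source.train_step(1)          # replaces source.log_dict
    assert snapshot["step"] == 0  # the snapshot is unaffected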
@@ -170,6 +184,21 @@ async def train_step(self, batch: list[Episode]):
             # Total GRPO loss
             loss = pg_loss + kl_penalty
             total_loss += loss.item()
+            total_kl_loss += kl_penalty.item()
+            total_pg_loss += pg_loss.item()
+            total_ratio_mean += ratio.detach().float().cpu().numpy().mean()
+            total_ratio_std += ratio.detach().float().cpu().numpy().std()
+
+            # Calculate average response length for this episode
+            episode_response_len = sum(
+                len(response) for response in response_texts
+            ) / len(response_texts)
+            total_response_len += episode_response_len
+
+            # Calculate average advantages for this episode
+            episode_advantages = advantages_tensor.detach().float().cpu().numpy().mean()
+            total_advantages += episode_advantages
+
             num_groups_processed += len(groups)
 
         self.optimizer.zero_grad()
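Note: the ratio statistics above take a detach().float().cpu().numpy() round trip for every episode. For the mean, reducing on-device and transferring a single scalar with .item() is equivalent and avoids materializing a NumPy array; a small sketch assuming ratio is a torch tensor. (For .std() the two paths are not interchangeable: NumPy defaults to ddof=0 while torch applies Bessel's correction unless told otherwise.)

    import torch

    ratio = torch.tensor([0.9, 1.0, 1.1])

    # Round trip through NumPy, as in the diff above.
    m1 = ratio.detach().float().cpu().numpy().mean()

    # Equivalent for the mean: reduce first, then move one scalar to Python.
    m2 = ratio.detach().float().mean().item()

    assert abs(m1 - m2) < 1e-6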
@@ -180,7 +209,30 @@ async def train_step(self, batch: list[Episode]):
 
         self.optimizer.step()
 
-        avg_loss = total_loss / len(batch) if batch else 0.0
+        # Compute averaged metrics across the batch
+        if batch:
+            avg_loss = total_loss / len(batch)
+            avg_kl_loss = total_kl_loss / len(batch)
+            avg_pg_loss = total_pg_loss / len(batch)
+            avg_ratio_mean = total_ratio_mean / len(batch)
+            avg_ratio_std = total_ratio_std / len(batch)
+            avg_response_len = total_response_len / len(batch)
+            avg_advantages = total_advantages / len(batch)
+        else:
+            avg_loss = avg_kl_loss = avg_pg_loss = avg_ratio_mean = avg_ratio_std = (
+                avg_response_len
+            ) = avg_advantages = 0.0
+
+        # Store averaged metrics for external logging
+        self.log_dict = {
+            "loss/total": avg_loss,
+            "loss/kl": avg_kl_loss,
+            "loss/policy": avg_pg_loss,
+            "metrics/ratio_mean": avg_ratio_mean,
+            "metrics/ratio_std": avg_ratio_std,
+            "metrics/response_len": avg_response_len,
+            "metrics/advantages": avg_advantages,
+        }
 
         return {"loss": avg_loss, "groups_processed": num_groups_processed}

@@ -460,7 +512,7 @@ async def continuous_rollouts():
                 print(
                     f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
                 )
-                logger.log("reward/rollout", avg_reward, rollout_count)
+                logger.log("metrics/reward_per_ten_rollout", avg_reward, rollout_count)
 
     async def continuous_training():
         training_step = 0
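Note: the old key reward/rollout suggested a per-rollout value; the new name records that avg_reward is logged once per ten rollouts. A hedged sketch of that cadence (the window of ten is inferred from the metric name, not shown in this hunk):

    def log_windowed_reward(logger, rewards, window: int = 10) -> None:
        """Log the mean reward of each completed window of rollouts."""
        buffer = []
        for rollout_count, reward in enumerate(rewards, start=1):
            buffer.append(reward)
            if rollout_count % window == 0:
                avg_reward = sum(buffer) / len(buffer)
                logger.log("metrics/reward_per_ten_rollout", avg_reward, rollout_count)
                buffer.clear()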
@@ -476,7 +528,11 @@ async def continuous_training():
             if training_result:
                 loss_value = training_result.get("loss", 0.0)
                 print(f"Latest loss: {loss_value}")
                 logger.log("loss/training_step", loss_value, training_step)
+
+                # Get and log detailed metrics
+                metrics = await trainer.get_metrics.choose()
+                logger.log_dict(metrics, training_step)
 
                 # await trainer.update_weights(policy)
 
     print("Starting GRPO training loops...")
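Note: the training loop now pulls the trainer's metric dict once per step via trainer.get_metrics.choose() and hands it to logger.log_dict. If a logger only exposes the scalar log(name, value, step) call seen above, the same effect is a keyed loop; a minimal fallback sketch (log_dict_fallback is a hypothetical helper, not part of this PR):

    def log_dict_fallback(logger, metrics: dict, step: int) -> None:
        """Log each key/value pair as a scalar at the given step."""
        for name, value in metrics.items():
            logger.log(name, value, step)

    # Usage, mirroring the training loop above:
    #   metrics = await trainer.get_metrics.choose()
    #   log_dict_fallback(logger, metrics, training_step)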