Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller import ServiceConfig, spawn_service
from forge.controller.actor import ForgeActor
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

Expand Down Expand Up @@ -388,6 +390,13 @@ async def main():
group_size = 1
model = "Qwen/Qwen3-1.7B"

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
"wandb",
freq=1,
project="grpo-training",
)

# ---- Setup services ---- #
default_service_cfg = ServiceConfig(
procs_per_replica=1,
Expand Down Expand Up @@ -498,6 +507,7 @@ async def continuous_rollouts():
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
)
logger.log("reward/rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
Expand All @@ -511,7 +521,9 @@ async def continuous_training():
if training_step % 10 == 0:
print(f"Completed {training_step} training steps")
if training_result:
print(f"Latest loss: {training_result.get('loss', 'N/A')}")
loss_value = training_result.get("loss", 0.0)
print(f"Latest loss: {loss_value}")
logger.log("loss/training_step", loss_value, training_step)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably want to make sure that the input is a dictionary, like here: https://github.com/pytorch/torchtune/blob/67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/recipes/full_finetune_distributed.py#L909

I think that we will have an abundance of metrics coming from dataset and reliability metrics. This is how i envisioned it being used: https://fb.workplace.com/groups/1189731669410969/permalink/1279384097112392/

I understand that its a a simple PR just to get logging started. Just sharing where I think we should land after a few iterations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah great callout, but lets land this one for now and iterate to that point.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@felipemello1 Thanks for the suggestions! This is a simple PR to simply integrate the logger. In my next PR, I will add logging for various metrics using log_dict.
@joecummings Thanks for approval. I will go ahead and land this one for now.

# await trainer.update_weights(policy)

print("Starting GRPO training loops...")
Expand Down
Loading