63 changes: 34 additions & 29 deletions apps/grpo/main.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
+
 import asyncio
 import logging
 import uuid
@@ -13,13 +15,16 @@
 import torch
 import torch.nn.functional as F
 from datasets import load_dataset
-from forge.actors.policy import EngineConfig, Policy, SamplingConfig
+from forge.actors.policy import Policy
 from forge.actors.replay_buffer import ReplayBuffer
+from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from omegaconf import DictConfig
+from src.forge.data.utils import exclude_service
 from torch import nn
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -286,11 +291,11 @@ async def forward(self, episode: Episode) -> torch.Tensor:
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
 
-    path: str
-    revision: str
-    data_split: str
-    streaming: bool
-    model: str
+    path: str = "openai/gsm8k"
+    revision: str = "main"
+    data_split: str = "train"
+    streaming: bool = True
+    model: str = "Qwen/Qwen3-1.7B-Base"
 
     @endpoint
     def setup(self):
Expand Down Expand Up @@ -326,12 +331,13 @@ async def pad_token(self):
return self.tokenizer.pad_token_id


async def main():
async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
group_size = 4
model = "Qwen/Qwen3-1.7B-Base"
max_req_tokens = 512
max_res_tokens = 128
# Get parameters from config with fallbacks
group_size = cfg.group_size
model = cfg.model
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
@@ -351,43 +357,37 @@ async def main():
         reward_actor,
     ) = await asyncio.gather(
         spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.dataset.service),
             DatasetActor,
-            path="openai/gsm8k",
-            revision="main",
-            data_split="train",
-            streaming=True,
-            model=model,
+            **exclude_service(cfg.dataset),
         ),
         spawn_service(
-            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            ServiceConfig(**cfg.policy.service),
             Policy,
-            engine_config=EngineConfig(model=model),
-            sampling_config=SamplingConfig(n=group_size, max_tokens=max_res_tokens),
+            **exclude_service(cfg.policy),
         ),
         spawn_service(
-            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            ServiceConfig(**cfg.trainer.service),
             Trainer,
-            learning_rate=1e-5,
-            model_name=model,
+            **exclude_service(cfg.trainer),
         ),
         spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.replay_buffer.service),
             ReplayBuffer,
-            batch_size=4,
-            max_policy_age=1,
+            **exclude_service(cfg.replay_buffer),
         ),
         spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.compute_advantages.service),
             ComputeAdvantages,
         ),
         spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+            ServiceConfig(**cfg.ref_model.service),
             RefModel,
             model_name=model,
         ),
         spawn_service(
-            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ServiceConfig(**cfg.reward_actor.service),
             RewardActor,
             reward_functions=[MathReward(), ThinkingReward()],
         ),
Expand Down Expand Up @@ -481,5 +481,10 @@ async def continuous_training():
)


@parse
def recipe_main(cfg: DictConfig) -> None:
asyncio.run(main(cfg))


if __name__ == "__main__":
asyncio.run(main())
recipe_main()
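For context on the new entry point: `@parse` is expected to read a `--config` path, load the YAML into an OmegaConf `DictConfig`, and pass it to the wrapped function. A minimal Python sketch of such a decorator follows; this is an illustration only, not the actual `forge.cli.config.parse` source.

# Hypothetical sketch of a parse-style decorator (illustration only; the real
# forge.cli.config.parse may behave differently).
import argparse
import functools

from omegaconf import OmegaConf


def parse(fn):
    @functools.wraps(fn)
    def wrapper():
        parser = argparse.ArgumentParser()
        parser.add_argument("--config", required=True, help="path to a YAML config file")
        args = parser.parse_args()
        cfg = OmegaConf.load(args.config)  # a DictConfig for a mapping-style YAML
        return fn(cfg)

    return wrapper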
74 changes: 74 additions & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -0,0 +1,74 @@
# GRPO Training Configuration

# Global configuration
group_size: 4
batch_size: 4
max_req_tokens: 512
max_res_tokens: 128
model: "Qwen/Qwen3-1.7B-Base"

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_config:
    n: 4
    max_tokens: 128
    temperature: 1.0
    top_p: 1.0
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Trainer configuration
trainer:
  learning_rate: 1e-5
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 0
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Compute advantages configuration
compute_advantages:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Reference model configuration
ref_model:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Reward actor configuration
reward_actor:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false
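To make the wiring concrete, here is a small Python sketch of how this file is consumed (paths and values follow the diff above; the interpolation behavior is standard OmegaConf):

from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_1_7b.yaml")

# ${model} and ${batch_size} resolve against the top-level keys on access.
assert cfg.policy.engine_config.model == "Qwen/Qwen3-1.7B-Base"
assert cfg.replay_buffer.batch_size == 4

# Each actor section splits in two, mirroring the spawn_service calls in
# apps/grpo/main.py: the 'service' block feeds ServiceConfig, the rest
# becomes actor kwargs (which is what exclude_service returns).
service_kwargs = OmegaConf.to_container(cfg.policy.service, resolve=True)
actor_kwargs = {k: v for k, v in cfg.policy.items() if k != "service"}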
8 changes: 5 additions & 3 deletions apps/vllm/main.py
@@ -10,13 +10,13 @@
 """
 
 import asyncio
-import sys
 
 from forge.actors.policy import Policy
 from forge.cli.config import parse
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 
 from omegaconf import DictConfig
+from src.forge.data.utils import exclude_service
 from vllm.outputs import RequestOutput


@@ -29,7 +29,9 @@ async def run(cfg: DictConfig):
     print("Spawning service...")
 
     policy = await spawn_service(
-        ServiceConfig(**cfg.policy.service), Policy, **cfg.policy
+        ServiceConfig(**cfg.policy.service),
+        Policy,
+        **exclude_service(cfg.policy),
     )
 
     async with policy.session():
@@ -55,4 +57,4 @@ def recipe_main(cfg: DictConfig) -> None:
 
 
 if __name__ == "__main__":
-    sys.exit(recipe_main())
+    recipe_main()
4 changes: 2 additions & 2 deletions src/forge/actors/replay_buffer.py
@@ -17,8 +17,8 @@
 class ReplayBuffer(ForgeActor):
     """Simple in-memory replay buffer implementation."""
 
-    batch_size: int
-    max_policy_age: int
+    batch_size: int = 4
+    max_policy_age: int = 0
     seed: int | None = None
 
     @endpoint
12 changes: 10 additions & 2 deletions src/forge/data/utils.py
@@ -210,6 +210,14 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
             batch[k] = v.to(device)
         else:
             raise ValueError(
-                f"""To use batch_to_device, all elements in the batch must be a dict, Tensor, or BlockMask with flexattention enabled.
-                Got key "{k}" with value of type {type(v)}"""
+                f"To use batch_to_device, all elements in the batch must be a dict, "
+                f"Tensor, or BlockMask with flexattention enabled. "
+                f'Got key "{k}" with value of type {type(v)}'
             )


+def exclude_service(config_dict: dict) -> dict:
Review comment (Member): Fun - we may want to put all of this in its own "frontend" file called omegaconf_frontend (idk, smth like that) so people can keep track of things they might want to switch out with their own frontend.

"""Remove 'service' key from config dict without modifying original."""
result = config_dict.copy()
result.pop("service", None)
return result
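A quick illustration of the helper's contract (the values here are hypothetical): only the 'service' key is dropped, and the shallow copy leaves the caller's dict untouched.

cfg = {
    "path": "openai/gsm8k",
    "service": {"procs_per_replica": 1, "num_replicas": 1},
}
actor_kwargs = exclude_service(cfg)
assert actor_kwargs == {"path": "openai/gsm8k"}
assert "service" in cfg  # the original mapping is not modified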
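Following up on the review comment above: a hypothetical `omegaconf_frontend` module could gather the OmegaConf-facing glue in one place so another config frontend could be swapped in later. The module path and helper name below are illustrative only, not part of this PR.

# src/forge/cli/omegaconf_frontend.py (hypothetical module path)
from forge.controller.service import ServiceConfig
from omegaconf import DictConfig


def split_actor_config(config: DictConfig) -> tuple[ServiceConfig, dict]:
    """Split one actor section into its ServiceConfig and remaining actor kwargs."""
    actor_kwargs = {k: v for k, v in config.items() if k != "service"}
    return ServiceConfig(**config.service), actor_kwargs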