-
Notifications
You must be signed in to change notification settings - Fork 24
Add YAML config file for grpo.main #141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
da21e1d
5c72908
b4d7a61
02d77c6
fd1d38b
f79beee
d8d775a
e423c44
4815c05
77d41e4
58fbb07
f90c41f
f39f9c2
b3c14a8
310ca5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,22 +4,28 @@ | |
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml | ||
|
||
import asyncio | ||
import logging | ||
import sys | ||
import uuid | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Optional | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from datasets import load_dataset | ||
from forge.actors.policy import EngineConfig, Policy, SamplingConfig | ||
from forge.actors.policy import Policy | ||
from forge.actors.replay_buffer import ReplayBuffer | ||
from forge.cli.config import parse | ||
from forge.controller.actor import ForgeActor | ||
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service | ||
from forge.data.rewards import MathReward, ThinkingReward | ||
from forge.util.metric_logging import get_metric_logger | ||
from monarch.actor import endpoint | ||
from omegaconf import DictConfig | ||
from src.forge.data.utils import exclude_service | ||
from torch import nn | ||
from transformers import AutoModelForCausalLM | ||
from vllm.transformers_utils.tokenizer import get_tokenizer | ||
|
@@ -286,11 +292,11 @@ async def forward(self, episode: Episode) -> torch.Tensor: | |
class DatasetActor(ForgeActor): | ||
"""Actor wrapper for HuggingFace dataset to provide async interface.""" | ||
|
||
path: str | ||
revision: str | ||
data_split: str | ||
streaming: bool | ||
model: str | ||
path: str = "openai/gsm8k" | ||
revision: str = "main" | ||
data_split: str = "train" | ||
streaming: bool = True | ||
model: str = "Qwen/Qwen3-1.7B-Base" | ||
|
||
@endpoint | ||
def setup(self): | ||
|
@@ -326,12 +332,17 @@ async def pad_token(self): | |
return self.tokenizer.pad_token_id | ||
|
||
|
||
async def main(): | ||
async def main(cfg: DictConfig): | ||
"""Main GRPO training loop with rollout and training processes.""" | ||
group_size = 4 | ||
model = "Qwen/Qwen3-1.7B-Base" | ||
max_req_tokens = 512 | ||
max_res_tokens = 128 | ||
# Get parameters from config with fallbacks | ||
group_size = cfg.get("group_size", 4) | ||
model = ( | ||
cfg.get("policy", {}) | ||
.get("engine_config", {}) | ||
.get("model", "Qwen/Qwen3-1.7B-Base") | ||
) | ||
max_req_tokens = cfg.get("max_req_tokens", 512) | ||
max_res_tokens = cfg.get("max_res_tokens", 128) | ||
|
||
# ---- Setup WandB Logger ---- # | ||
logger = get_metric_logger( | ||
|
@@ -351,43 +362,37 @@ async def main(): | |
reward_actor, | ||
) = await asyncio.gather( | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, num_replicas=1), | ||
ServiceConfig(**cfg.dataset.service), | ||
DatasetActor, | ||
path="openai/gsm8k", | ||
revision="main", | ||
data_split="train", | ||
streaming=True, | ||
model=model, | ||
**exclude_service(cfg.dataset), | ||
), | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), | ||
ServiceConfig(**cfg.policy.service), | ||
Policy, | ||
engine_config=EngineConfig(model=model), | ||
sampling_config=SamplingConfig(n=group_size, max_tokens=max_res_tokens), | ||
**exclude_service(cfg.policy), | ||
), | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), | ||
ServiceConfig(**cfg.trainer.service), | ||
Trainer, | ||
learning_rate=1e-5, | ||
model_name=model, | ||
**exclude_service(cfg.trainer), | ||
), | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, num_replicas=1), | ||
ServiceConfig(**cfg.replay_buffer.service), | ||
ReplayBuffer, | ||
batch_size=4, | ||
max_policy_age=1, | ||
**exclude_service(cfg.replay_buffer), | ||
), | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, num_replicas=1), | ||
ServiceConfig(**cfg.compute_advantages.service), | ||
ComputeAdvantages, | ||
), | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True), | ||
ServiceConfig(**cfg.ref_model.service), | ||
RefModel, | ||
model_name=model, | ||
), | ||
spawn_service( | ||
ServiceConfig(procs_per_replica=1, num_replicas=1), | ||
ServiceConfig(**cfg.reward_actor.service), | ||
RewardActor, | ||
reward_functions=[MathReward(), ThinkingReward()], | ||
), | ||
|
@@ -481,5 +486,10 @@ async def continuous_training(): | |
) | ||
|
||
|
||
@parse | ||
def recipe_main(cfg: DictConfig) -> None: | ||
asyncio.run(main(cfg)) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) | ||
sys.exit(recipe_main()) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# GRPO Training Configuration | ||
|
||
# Global configuration | ||
group_size: 4 | ||
max_req_tokens: 512 | ||
max_res_tokens: 128 | ||
|
||
# Dataset configuration | ||
dataset: | ||
path: "openai/gsm8k" | ||
revision: "main" | ||
data_split: "train" | ||
streaming: true | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
|
||
# Policy configuration | ||
policy: | ||
engine_config: | ||
model: "Qwen/Qwen3-1.7B-Base" | ||
tensor_parallel_size: 1 | ||
pipeline_parallel_size: 1 | ||
enforce_eager: true | ||
sampling_config: | ||
n: 4 | ||
max_tokens: 128 | ||
temperature: 1.0 | ||
top_p: 1.0 | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Trainer configuration | ||
trainer: | ||
learning_rate: 1e-5 | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Replay buffer configuration | ||
replay_buffer: | ||
batch_size: 4 | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
max_policy_age: 1 | ||
DNXie marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
|
||
# Compute advantages configuration | ||
compute_advantages: | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false | ||
|
||
# Reference model configuration | ||
ref_model: | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: true | ||
|
||
# Reward actor configuration | ||
reward_actor: | ||
service: | ||
procs_per_replica: 1 | ||
num_replicas: 1 | ||
with_gpus: false |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -210,6 +210,14 @@ def batch_to_device(batch: dict, device: torch.device) -> None: | |
batch[k] = v.to(device) | ||
else: | ||
raise ValueError( | ||
f"""To use batch_to_device, all elements in the batch must be a dict, Tensor, or BlockMask with flexattention enabled. | ||
Got key "{k}" with value of type {type(v)}""" | ||
f"To use batch_to_device, all elements in the batch must be a dict, " | ||
f"Tensor, or BlockMask with flexattention enabled. " | ||
f'Got key "{k}" with value of type {type(v)}' | ||
) | ||
|
||
|
||
def exclude_service(config_dict: dict) -> dict: | ||
|
||
"""Remove 'service' key from config dict without modifying original.""" | ||
result = config_dict.copy() | ||
result.pop("service", None) | ||
return result |
Uh oh!
There was an error while loading. Please reload this page.