Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel # noqa: F401
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
Expand All @@ -30,6 +31,7 @@
from omegaconf import DictConfig
from torch import nn
from torchstore.state_dict_utils import DELIM
from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

Expand Down Expand Up @@ -330,6 +332,7 @@ async def pad_token(self):

async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
titan_model = TitanJobModelConfig(name="qwen3", flavor="1.7B")
# Get parameters from config with fallbacks
group_size = cfg.group_size
model = cfg.model
Expand Down Expand Up @@ -381,6 +384,11 @@ async def main(cfg: DictConfig):
RefModel,
model_name=model,
),
# spawn_service(
# ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
# ReferenceModel,
# model=titan_model,
# ),
spawn_service(
ServiceConfig(**cfg.reward_actor.service),
RewardActor,
Expand Down
6 changes: 3 additions & 3 deletions src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def __getattr__(name):
from .replay_buffer import ReplayBuffer

return ReplayBuffer
elif name == "TitanRefModel":
from .reference_actor import TitanRefModel
elif name == "ReferenceModel":
from .reference_model import ReferenceModel

return TitanRefModel
return ReferenceModel
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
Loading
Loading