From 10507b5c59ee6d8fe6eeaa6e364d083d57a61898 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 27 Aug 2025 18:19:55 -0700 Subject: [PATCH 1/3] Add missing licenses --- src/forge/cli/config.py | 6 ++++++ src/forge/data/rewards.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/forge/cli/config.py b/src/forge/cli/config.py index 918823ab6..35eb13d9b 100644 --- a/src/forge/cli/config.py +++ b/src/forge/cli/config.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import argparse import functools import sys diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 72973ae95..644c69d1b 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import re from typing import Optional From e096eb115b9d52c26e21d7f41c9b59cf77d6ff1a Mon Sep 17 00:00:00 2001 From: joecummings Date: Sat, 30 Aug 2025 16:46:20 -0700 Subject: [PATCH 2/3] Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- apps/grpo/main.py | 70 ++++++++++++------------- src/forge/actors/__init__.py | 32 +++++------ src/forge/actors/policy.py | 27 ++++++---- src/forge/cli/download.py | 4 +- src/forge/cli/run.py | 6 +-- src/forge/controller/service/replica.py | 4 +- src/forge/data/datasets/packed.py | 6 +-- src/forge/data/datasets/sft_dataset.py | 1 + src/forge/data/tokenizer.py | 6 +-- src/forge/envs/chat.py | 1 + 10 files changed, 82 insertions(+), 75 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 145b6cd48..26875c8f1 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -19,6 +19,8 @@ from forge.data.rewards import MathReward, ThinkingReward from forge.util.metric_logging import get_metric_logger from monarch.actor import endpoint +from torchstore import MultiProcessStore +from torchstore._state_dict_utils import DELIM, push_state_dict from transformers import AutoModelForCausalLM, AutoTokenizer logger = logging.getLogger(__name__) @@ -67,7 +69,13 @@ class Group: class Episode: """Episode container for GRPO rollouts.""" - def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int): + def __init__( + self, + episode_id: int, + prompt: str, + target: str, + policy_version: int | None = None, + ): self.episode_id = episode_id self.prompt = prompt self.target = target @@ -87,11 +95,13 @@ def __init__( beta: float = 0.1, model_name: str = "", device: torch.device | None = None, + store: MultiProcessStore | None = None, ): super().__init__() self.learning_rate = learning_rate self.beta = beta # KL penalty coefficient self.model_name = model_name + self.store = store # Set device if device is None: @@ -189,31 +199,13 @@ async def train_step(self, batch: list[Episode]): return {"loss": avg_loss, "groups_processed": num_groups_processed} @endpoint - async def update_weights(self, policy_actor): + async def save_weights(self): """Update policy model weights with trainer's current weights.""" - # Time how long it takes to update weights start_time = time.time() - - # Set model to eval mode for weight extraction - self.model.eval() - - # Extract current model state dict - model_state_dict = self.model.state_dict() - - # Convert tensors to 
CPU for transfer (if they're on GPU) - cpu_state_dict = {} - for key, tensor in model_state_dict.items(): - cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor - - # Update the policy actor's model weights - await policy_actor.update_model_weights.choose(cpu_state_dict) - - # Set model back to training mode - self.model.train() - - # Log the time taken + assert self.store is not None, "Store must be provided to save weights" + await push_state_dict(self.store, self.model.state_dict(), "model_state_dict") end_time = time.time() - self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds") + self.logger.info(f"Saving weights took {end_time - start_time:.2f} seconds") class RewardActor(ForgeActor): @@ -347,12 +339,12 @@ async def main(): group_size = 1 model = "Qwen/Qwen3-1.7B" - # ---- Setup WandB Logger ---- # logger = get_metric_logger( "wandb", freq=1, project="grpo-training", ) + store = await MultiProcessStore.create_store() # ---- Setup services ---- # ( @@ -415,6 +407,9 @@ async def main(): print("All services initialized successfully!") + # print("Trying to save weights to torchstore...") + await trainer.save_weights.choose() + # ---- Core RL loops ---- # async def continuous_rollouts(): rollout_count = 0 @@ -424,14 +419,13 @@ async def continuous_rollouts(): print("Dataloader is empty, exiting continuous rollout") return prompt, target = sample["question"], sample["answer"] - version = 0 # await policy.get_current_version.choose() + actions, policy_version = await policy.generate.choose(prompt) episode = Episode( episode_id=rollout_count, prompt=prompt, target=target, - policy_version=version, + policy_version=policy_version, ) - actions = await policy.generate.choose(prompt) for action in actions: ref_logprobs = await ref_model.forward.choose(action.token_ids) reward = await reward_actor.evaluate_response.choose( @@ -462,23 +456,27 @@ async def continuous_rollouts(): logger.log("reward/rollout", avg_reward, rollout_count) async def continuous_training(): + on_policy_version = 0 training_step = 0 while True: - batch = await replay_buffer.sample.choose(curr_policy_version=0) + batch = await replay_buffer.sample.choose( + curr_policy_version=on_policy_version + ) if batch is None: await asyncio.sleep(0.1) else: training_result = await trainer.train_step.choose(batch) training_step += 1 if training_step % 10 == 0: - print(f"Completed {training_step} training steps") - if training_result: - loss_value = training_result.get("loss", 0.0) - print(f"Latest loss: {loss_value}") - logger.log("loss/training_step", loss_value, training_step) - # await trainer.update_weights(policy) - - print("Starting GRPO training loops...") + loss_value = training_result.get("loss", 0.0) + print( + f"Completed {training_step} training steps w/ loss: {loss_value}" + ) + logger.log("loss/training_step", loss_value, training_step) + print("Updating policy weights...") + await trainer.save_weights.choose() + + # print("Starting GRPO training loops...") rollout_task = asyncio.create_task(continuous_rollouts()) training_task = asyncio.create_task(continuous_training()) diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py index c521b813a..983fe4f68 100644 --- a/src/forge/actors/__init__.py +++ b/src/forge/actors/__init__.py @@ -4,25 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"] +# __all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"] -def __getattr__(name): - if name == "Policy": - from .policy import Policy +# def __getattr__(name): +# if name == "Policy": +# from .policy import Policy - return Policy - elif name == "PolicyRouter": - from .policy import PolicyRouter +# return Policy +# elif name == "PolicyRouter": +# from .policy import PolicyRouter - return PolicyRouter - elif name == "RLTrainer": - from .trainer import RLTrainer +# return PolicyRouter +# elif name == "RLTrainer": +# from .trainer import RLTrainer - return RLTrainer - elif name == "ReplayBuffer": - from .replay_buffer import ReplayBuffer +# return RLTrainer +# elif name == "ReplayBuffer": +# from .replay_buffer import ReplayBuffer - return ReplayBuffer - else: - raise AttributeError(f"module {__name__} has no attribute {name}") +# return ReplayBuffer +# else: +# raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 4a51f7225..86d81be32 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -13,6 +13,12 @@ from typing import Dict, List import torch + +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh + +from forge.data.sharding import VLLMSharding +from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig from monarch.actor import current_rank, endpoint, ProcMesh from torchstore import MultiProcessStore from torchstore._state_dict_utils import DELIM @@ -37,12 +43,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh - -from forge.data.sharding import VLLMSharding -from forge.interfaces import Policy as PolicyInterface -from forge.types import ProcessConfig - logger = logging.getLogger(__name__) @@ -216,7 +216,9 @@ def start_processing(self): self._run_task = asyncio.create_task(self.run()) @endpoint - async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]: + async def generate( + self, prompt: str, priority: int = 0 + ) -> tuple[List[CompletionOutput], int]: self.request_id += 1 % sys.maxsize request_id = str(self.request_id) # implement from a counter @@ -273,7 +275,8 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_fut = asyncio.Future() self.requests[request_id] = (parent_req, request_fut) - return await request_fut + x = await request_fut + return x, self.weights_version # Abstracted to match vllm # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 @@ -313,9 +316,13 @@ async def run(self): fut.set_result(request_output.outputs) @endpoint - async def update_weights(self): + async def update_weights(self) -> int: """Update the policy weights.""" - pass + # Wait for all current requests to finish, then publish model weights + await asyncio.gather(*self.requests.values()) + await self.policy_worker.update.call() + self.weights_version += 1 + return self.weights_version @endpoint async def stop(self): diff --git a/src/forge/cli/download.py b/src/forge/cli/download.py index 5938a1edd..69ebde9aa 100644 --- a/src/forge/cli/download.py +++ b/src/forge/cli/download.py @@ -13,11 +13,11 @@ from pathlib import Path -from forge.cli.subcommand import Subcommand - from huggingface_hub import 
snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError +from forge.cli.subcommand import Subcommand + # TODO: update this REPO_ID_FNAME = "original_repo_id" diff --git a/src/forge/cli/run.py b/src/forge/cli/run.py index da13e804e..4a556c1f8 100644 --- a/src/forge/cli/run.py +++ b/src/forge/cli/run.py @@ -11,12 +11,12 @@ from pathlib import Path -import forge -from forge.cli.subcommand import Subcommand - from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.run import get_args_parser as get_torchrun_args_parser, run +import forge +from forge.cli.subcommand import Subcommand + ROOT = Path(forge.__file__).parent.parent diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index b84e5eec7..a154f15a4 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -13,11 +13,11 @@ from enum import Enum from typing import Optional -from monarch.actor import ActorError - from forge.controller import ForgeActor from forge.types import ProcessConfig +from monarch.actor import ActorError + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 489811960..105921acb 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -10,14 +10,14 @@ from typing import Any, Generic, Iterable, Iterator, Optional, TypeVar import torch - -from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.data.dataset_metrics import AggregationType, Metric from torch.nn.attention.flex_attention import ( create_block_mask as create_block_mask_flex, ) from torchdata.stateful_dataloader import Stateful +from forge.data import CROSS_ENTROPY_IGNORE_IDX +from forge.data.dataset_metrics import AggregationType, Metric + from .dataset import DatasetInfo, InfiniteTuneIterableDataset logger = logging.getLogger(__name__) diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index f3e1ec781..e6f6edcfb 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Optional import torch + from forge.data import CROSS_ENTROPY_IGNORE_IDX from forge.data.dataset_metrics import DefaultTrainingMetricTransform from forge.data.utils import mask_messages, TuneMessage diff --git a/src/forge/data/tokenizer.py b/src/forge/data/tokenizer.py index 1ff322ad7..3cb90f79c 100644 --- a/src/forge/data/tokenizer.py +++ b/src/forge/data/tokenizer.py @@ -8,13 +8,13 @@ from typing import Any, Optional import jinja2 +from jinja2 import StrictUndefined + +from tokenizers import Tokenizer from forge.data.utils import truncate from forge.interfaces import BaseTokenizer, ModelTokenizer from forge.types import Message -from jinja2 import StrictUndefined - -from tokenizers import Tokenizer class HuggingFaceBaseTokenizer(BaseTokenizer): diff --git a/src/forge/envs/chat.py b/src/forge/envs/chat.py index 4a94c89d9..24a5981a6 100644 --- a/src/forge/envs/chat.py +++ b/src/forge/envs/chat.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field import torch + from forge.interfaces import Environment, Message, ModelTokenizer, Transform from forge.types import Action, Observation, State From 006c6d1b4960fa70a2f720f7098b1e2b0a5cc6f4 Mon Sep 17 00:00:00 2001 From: joecummings Date: Sat, 30 Aug 2025 17:16:03 -0700 Subject: [PATCH 3/3] Stub --- apps/grpo/main.py | 294 
+++++++++++++++++++------------------ src/forge/actors/policy.py | 1 + 2 files changed, 153 insertions(+), 142 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 26875c8f1..2fcbac705 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -347,62 +347,72 @@ async def main(): store = await MultiProcessStore.create_store() # ---- Setup services ---- # - ( - dataloader, - policy, - trainer, - replay_buffer, - compute_advantages, - ref_model, - reward_actor, - ) = await asyncio.gather( - spawn_service( - ServiceConfig(procs_per_replica=1, num_replicas=1), - DatasetActor, - path="openai/gsm8k", - config_name="main", - split="train", - streaming=True, - ), - spawn_service( - ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), - Policy, - config=PolicyConfig( - worker_params=WorkerConfig(model=model), - sampling_params=SamplingOverrides( - num_samples=group_size, max_tokens=16 - ), - ), - ), - spawn_service( - ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), - Trainer, - learning_rate=1e-5, - beta=0.1, - model_name=model, - ), - spawn_service( - ServiceConfig(procs_per_replica=1, num_replicas=1), - ReplayBuffer, - batch_size=4, - max_policy_age=1, - ), - spawn_service( - ServiceConfig(procs_per_replica=1, num_replicas=1), - ComputeAdvantages, - gamma=0.99, - lambda_=0.95, - ), - spawn_service( - ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True), - RefModel, - model_name=model, - ), - spawn_service( - ServiceConfig(procs_per_replica=1, num_replicas=1), - RewardActor, - reward_functions=[MathReward(), ThinkingReward()], - ), + # ( + # dataloader, + # policy, + # trainer, + # replay_buffer, + # compute_advantages, + # ref_model, + # reward_actor, + # ) = await asyncio.gather( + # spawn_service( + # ServiceConfig(procs_per_replica=1, num_replicas=1), + # DatasetActor, + # path="openai/gsm8k", + # config_name="main", + # split="train", + # streaming=True, + # ), + # spawn_service( + # ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), + # Policy, + # config=PolicyConfig( + # worker_params=WorkerConfig(model=model), + # sampling_params=SamplingOverrides( + # num_samples=group_size, max_tokens=16 + # ), + # ), + # ), + # spawn_service( + # ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), + # Trainer, + # learning_rate=1e-5, + # beta=0.1, + # model_name=model, + # store=store, + # ), + # spawn_service( + # ServiceConfig(procs_per_replica=1, num_replicas=1), + # ReplayBuffer, + # batch_size=4, + # max_policy_age=1, + # ), + # spawn_service( + # ServiceConfig(procs_per_replica=1, num_replicas=1), + # ComputeAdvantages, + # gamma=0.99, + # lambda_=0.95, + # ), + # spawn_service( + # ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True), + # RefModel, + # model_name=model, + # ), + # spawn_service( + # ServiceConfig(procs_per_replica=1, num_replicas=1), + # RewardActor, + # reward_functions=[MathReward(), ThinkingReward()], + # ), + # ) + + trainer = await spawn_service( + ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True), + Trainer, + learning_rate=1e-5, + beta=0.1, + model_name=model, + store=store, ) print("All services initialized successfully!") @@ -411,92 +421,92 @@ async def main(): await trainer.save_weights.choose() # ---- Core RL loops ---- # - async def continuous_rollouts(): - rollout_count = 0 - while True: - sample = await dataloader.__next__.choose() - if sample is None: - print("Dataloader is empty, exiting continuous rollout") - return - prompt, target = 
sample["question"], sample["answer"] - actions, policy_version = await policy.generate.choose(prompt) - episode = Episode( - episode_id=rollout_count, - prompt=prompt, - target=target, - policy_version=policy_version, - ) - for action in actions: - ref_logprobs = await ref_model.forward.choose(action.token_ids) - reward = await reward_actor.evaluate_response.choose( - prompt=prompt, response=action.text, target=target - ) - episode.add_group( - Group( - response=action.text, - ref_logprobs=ref_logprobs, - reward=reward, - ) - ) - - advantages = await compute_advantages.__call__.choose(episode.groups) - for advantage, group in zip(advantages, episode.groups): - group.advantage = advantage - - await replay_buffer.add.choose(episode) - - rollout_count += 1 - if rollout_count % 10 == 0: - avg_reward = sum(group.reward for group in episode.groups) / len( - episode.groups - ) - print( - f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" - ) - logger.log("reward/rollout", avg_reward, rollout_count) - - async def continuous_training(): - on_policy_version = 0 - training_step = 0 - while True: - batch = await replay_buffer.sample.choose( - curr_policy_version=on_policy_version - ) - if batch is None: - await asyncio.sleep(0.1) - else: - training_result = await trainer.train_step.choose(batch) - training_step += 1 - if training_step % 10 == 0: - loss_value = training_result.get("loss", 0.0) - print( - f"Completed {training_step} training steps w/ loss: {loss_value}" - ) - logger.log("loss/training_step", loss_value, training_step) - print("Updating policy weights...") - await trainer.save_weights.choose() - - # print("Starting GRPO training loops...") - rollout_task = asyncio.create_task(continuous_rollouts()) - training_task = asyncio.create_task(continuous_training()) - - try: - await asyncio.gather(rollout_task, training_task) - except KeyboardInterrupt: - print("Training interrupted by user") - rollout_task.cancel() - training_task.cancel() - finally: - print("Shutting down...") - await asyncio.gather( - shutdown_service(policy), - shutdown_service(trainer), - shutdown_service(replay_buffer), - shutdown_service(dataloader), - shutdown_service(compute_advantages), - shutdown_service(ref_model), - shutdown_service(reward_actor), - ) + # async def continuous_rollouts(): + # rollout_count = 0 + # while True: + # sample = await dataloader.__next__.choose() + # if sample is None: + # print("Dataloader is empty, exiting continuous rollout") + # return + # prompt, target = sample["question"], sample["answer"] + # actions, policy_version = await policy.generate.choose(prompt) + # episode = Episode( + # episode_id=rollout_count, + # prompt=prompt, + # target=target, + # policy_version=policy_version, + # ) + # for action in actions: + # ref_logprobs = await ref_model.forward.choose(action.token_ids) + # reward = await reward_actor.evaluate_response.choose( + # prompt=prompt, response=action.text, target=target + # ) + # episode.add_group( + # Group( + # response=action.text, + # ref_logprobs=ref_logprobs, + # reward=reward, + # ) + # ) + + # advantages = await compute_advantages.__call__.choose(episode.groups) + # for advantage, group in zip(advantages, episode.groups): + # group.advantage = advantage + + # await replay_buffer.add.choose(episode) + + # rollout_count += 1 + # if rollout_count % 10 == 0: + # avg_reward = sum(group.reward for group in episode.groups) / len( + # episode.groups + # ) + # print( + # f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" + # ) + # 
logger.log("reward/rollout", avg_reward, rollout_count) + + # async def continuous_training(): + # on_policy_version = 0 + # training_step = 0 + # while True: + # batch = await replay_buffer.sample.choose( + # curr_policy_version=on_policy_version + # ) + # if batch is None: + # await asyncio.sleep(0.1) + # else: + # training_result = await trainer.train_step.choose(batch) + # training_step += 1 + # if training_step % 10 == 0: + # loss_value = training_result.get("loss", 0.0) + # print( + # f"Completed {training_step} training steps w/ loss: {loss_value}" + # ) + # logger.log("loss/training_step", loss_value, training_step) + # print("Updating policy weights...") + # await trainer.save_weights.choose() + + # # print("Starting GRPO training loops...") + # rollout_task = asyncio.create_task(continuous_rollouts()) + # training_task = asyncio.create_task(continuous_training()) + + # try: + # await asyncio.gather(rollout_task, training_task) + # except KeyboardInterrupt: + # print("Training interrupted by user") + # rollout_task.cancel() + # training_task.cancel() + # finally: + # print("Shutting down...") + # await asyncio.gather( + # shutdown_service(policy), + # shutdown_service(trainer), + # shutdown_service(replay_buffer), + # shutdown_service(dataloader), + # shutdown_service(compute_advantages), + # shutdown_service(ref_model), + # shutdown_service(reward_actor), + # ) if __name__ == "__main__": diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 86d81be32..5cc600086 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -101,6 +101,7 @@ class Policy(PolicyInterface): lora_request: LoRARequest | None = None tokenization_kwargs: dict = field(default_factory=dict) policy_worker: "PolicyWorker" = None + store: MultiProcessStore | None = None def __post_init__(self): self._run_task: asyncio.Task | None = None