diff --git a/apps/grpo/main.py b/apps/grpo/main.py
new file mode 100644
index 000000000..80e6ee10a
--- /dev/null
+++ b/apps/grpo/main.py
@@ -0,0 +1,530 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import time
+from dataclasses import dataclass
+from typing import Callable
+
+import torch
+from datasets import load_dataset
+from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.controller import ServiceConfig, spawn_service
+from forge.controller.actor import ForgeActor
+from monarch.actor import endpoint
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def compute_sequence_logprobs(
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ requires_grad: bool = True,
+) -> torch.Tensor:
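+    """Compute the summed log-probability of each sequence under `model`.
+
+    Gathers the log-prob of each next token and sums over non-padded positions,
+    returning one scalar per batch element.
+    """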
+ context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
+
+ with context_manager:
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+ logits = outputs.logits
+
+ # Apply log softmax to get log probabilities
+ log_probs = torch.log_softmax(logits, dim=-1)
+
+ # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
+ shifted_input_ids = input_ids[:, 1:] # Remove first token
+ shifted_log_probs = log_probs[:, :-1, :] # Remove last logit
+
+ # Gather log probabilities for actual tokens
+ token_log_probs = torch.gather(
+ shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
+ ).squeeze(-1)
+
+ # Sum log probabilities across sequence (masked by attention)
+ shifted_attention_mask = attention_mask[:, 1:]
+ sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
+
+ return sequence_log_probs
+
+
+@dataclass
+class Group:
+ response: str # The response text for tokenization
+ ref_logprobs: torch.Tensor
+ reward: float
+ advantage: float = 0.0
+
+
+class Episode:
+ """Episode container for GRPO rollouts."""
+
+ def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int):
+ self.episode_id = episode_id
+ self.prompt = prompt
+ self.target = target
+ self.policy_version = policy_version
+ self.groups: list[Group] = []
+
+ def add_group(self, group: Group):
+ self.groups.append(group)
+
+
+class Trainer(ForgeActor):
+ """GRPO Trainer implementation for policy optimization."""
+
+ def __init__(
+ self,
+ learning_rate: float = 1e-5,
+ beta: float = 0.1,
+ model_name: str = "",
+ device: torch.device | None = None,
+ ):
+ super().__init__()
+ self.learning_rate = learning_rate
+ self.beta = beta # KL penalty coefficient
+ self.model_name = model_name
+
+ # Set device
+ if device is None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ else:
+ self.device = device
+
+ # Initialize model and tokenizer
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(self.device)
+ self.model.train()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ # Initialize optimizer
+ self.optimizer = torch.optim.AdamW(
+ self.model.parameters(), lr=self.learning_rate
+ )
+
+ self.logger.info(f"Model initialized on {self.device}")
+
+ @endpoint
+ async def train_step(self, batch: list[Episode]):
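+        """Run one optimizer step per episode; return average loss and group count."""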
+ total_loss = 0.0
+ num_groups_processed = 0
+
+ for episode in batch:
+ groups = episode.groups
+
+ # Collect all response texts and corresponding data
+ response_texts = []
+ ref_logprobs_list = []
+ advantages_list = []
+
+ for group in groups:
+ response_texts.append(group.response)
+ ref_logprobs_list.append(group.ref_logprobs)
+ advantages_list.append(group.advantage)
+
+ # Tokenize all responses in batch
+ tokenized = self.tokenizer(
+ response_texts,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ max_length=512, # Adjust based on your needs
+ )
+
+ input_ids = tokenized["input_ids"].to(self.device)
+ attention_mask = tokenized["attention_mask"].to(self.device)
+
+ # Compute current policy log probabilities using the model
+ current_logprobs = compute_sequence_logprobs(
+ self.model, input_ids, attention_mask, requires_grad=True
+ )
+
+ # Convert ref_logprobs and advantages to tensors
+ ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device)
+ advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to(
+ self.device
+ )
+
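+            # GRPO-style loss used here (simplified sketch):
+            #   L = -mean(exp(logp_cur - logp_ref) * A)
+            #       + beta * mean((logp_cur - logp_ref) ** 2)
+            # i.e. an importance-weighted policy-gradient term plus a squared
+            # log-ratio penalty standing in for an explicit KL estimate.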
+ # Compute GRPO loss components
+ # Ratio between current policy and reference policy
+ ratio = torch.exp(current_logprobs - ref_logprobs_tensor)
+
+ # Policy gradient loss weighted by advantages
+ pg_loss = -torch.mean(ratio * advantages_tensor)
+
+ # KL penalty to prevent policy from deviating too far from reference
+ kl_penalty = self.beta * torch.mean(
+ (current_logprobs - ref_logprobs_tensor) ** 2
+ )
+
+ # Total GRPO loss
+ loss = pg_loss + kl_penalty
+ total_loss += loss.item()
+ num_groups_processed += len(groups)
+
+ self.optimizer.zero_grad()
+ loss.backward()
+
+ # Gradient clipping (optional but recommended for stability)
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+ self.optimizer.step()
+
+ avg_loss = total_loss / len(batch) if batch else 0.0
+
+ return {"loss": avg_loss, "groups_processed": num_groups_processed}
+
+ @endpoint
+ async def update_weights(self, policy_actor):
+ """Update policy model weights with trainer's current weights."""
+ # Time how long it takes to update weights
+ start_time = time.time()
+
+ # Set model to eval mode for weight extraction
+ self.model.eval()
+
+ # Extract current model state dict
+ model_state_dict = self.model.state_dict()
+
+ # Convert tensors to CPU for transfer (if they're on GPU)
+ cpu_state_dict = {}
+ for key, tensor in model_state_dict.items():
+ cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
+
+ # Update the policy actor's model weights
+ await policy_actor.update_model_weights.choose(cpu_state_dict)
+
+ # Set model back to training mode
+ self.model.train()
+
+ # Log the time taken
+ end_time = time.time()
+ self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
+
+
+def math_scoring_function(prompt: str, response: str, target: str) -> float:
+    """Score whether the final numeric answer in the response matches the target."""
+    import re
+
+    # Extract the expected answer from the target (GSM8K answers may contain commas)
+    try:
+        expected_answer = float(target.strip().replace(",", ""))
+    except ValueError:
+        expected_answer = None
+
+ # Extract model answer from response
+ patterns = [
+ r"####\s*([+-]?\d+(?:\.\d+)?)", # GSM8K style answer format
+ r"(?:the\s+)?answer\s+is\s*([+-]?\d+(?:\.\d+)?)",
+ r"(?:answer:|result:)\s*([+-]?\d+(?:\.\d+)?)",
+ r"=\s*([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)", # equals near end
+ r"\b([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)", # number at end
+ r"([+-]?\d+(?:\.\d+)?)", # any number (fallback)
+ ]
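+    # Patterns are tried in order and the last match from the first matching pattern
+    # wins, e.g. "the answer is 42." resolves to 42.0 via the second pattern.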
+
+ model_answer = None
+ response_lower = response.lower().strip()
+ for pattern in patterns:
+ matches = re.findall(pattern, response_lower)
+ if matches:
+ model_answer = float(matches[-1])
+ break
+
+ if expected_answer is None or model_answer is None:
+ return 0.1 # Partial credit for attempting
+
+ # Check if answers match (with some tolerance for floating point)
+ if abs(expected_answer - model_answer) < 1e-6:
+ return 1.0 # Correct answer
+ else:
+ return 0.0 # Incorrect answer
+
+
+def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
+    """Score whether the response uses <think>...</think> reasoning tags."""
+    # Check that the response contains both the opening and closing think tags
+    if "<think>" in response.lower() and "</think>" in response.lower():
+        return 0.5
+    else:
+        return 0.0
+
+
+class RewardActor(ForgeActor):
+ """Reward actor that uses a list of scoring functions."""
+
+ def __init__(self, scoring_functions: list[Callable]):
+ super().__init__()
+ self.scoring_functions = scoring_functions
+
+ @endpoint
+ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+ total_reward = 0.0
+ for scoring_fn in self.scoring_functions:
+ reward = scoring_fn(prompt, response, target)
+ total_reward += reward
+ return total_reward
+
+
+class ComputeAdvantages(ForgeActor):
+ """Compute advantages for GRPO using reward signals."""
+
+ def __init__(self, gamma: float = 0.99, lambda_: float = 0.95):
+ super().__init__()
+ self.gamma = gamma # Discount factor
+ self.lambda_ = lambda_ # GAE lambda parameter
+
+ @endpoint
+ async def __call__(self, groups: list[Group]) -> list[float]:
+ # Extract rewards from groups
+ rewards = [group.reward for group in groups]
+ num_groups = len(groups)
+
+ # For simplicity, use reward-to-go as advantages
+ # This is a valid advantage estimator: A(s,a) = Q(s,a) - V(s)
+ # where Q(s,a) ≈ reward-to-go and V(s) ≈ average reward
+
+ # Compute discounted reward-to-go for each step
+ reward_to_go = []
+ running_reward = 0.0
+
+ # Calculate discounted returns (reward-to-go)
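+        # e.g. rewards [1.0, 0.0, 0.5] with gamma=0.99 give
+        # reward_to_go = [1.0 + 0.99 * 0.495, 0.495, 0.5] = [1.49005, 0.495, 0.5]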
+ for t in reversed(range(num_groups)):
+ running_reward = rewards[t] + self.gamma * running_reward
+ reward_to_go.insert(0, running_reward)
+
+ # Compute baseline (mean of rewards) and advantages
+ baseline = sum(rewards) / len(rewards) if rewards else 0.0
+ advantages = [rtg - baseline for rtg in reward_to_go]
+
+ # Normalize advantages to have zero mean and unit variance
+ if len(advantages) > 1:
+ mean_adv = sum(advantages) / len(advantages)
+ var_adv = sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)
+ std_adv = (var_adv**0.5) if var_adv > 1e-8 else 1.0
+ advantages = [(a - mean_adv) / std_adv for a in advantages]
+
+ return advantages
+
+
+class RefModel(ForgeActor):
+ def __init__(self, model_name, device: torch.device | None = None):
+ super().__init__()
+ self.model_name = model_name
+
+ # Set device
+ if device is None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ else:
+ self.device = device
+
+ # Initialize model and tokenizer
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(self.device)
+
+ # Set model to eval mode for reference computations
+ self.model.eval()
+
+ self.logger.info(f"Model initialized on {self.device}")
+
+ @endpoint
+ async def forward(self, token_ids: list[int]) -> torch.Tensor:
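+        """Return the reference model's summed log-prob for one token sequence."""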
+ # Use provided token_ids directly
+ input_ids = (
+ torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
+ )
+ # Create attention mask of all 1s since we have actual tokens (no padding)
+ attention_mask = torch.ones_like(input_ids).to(self.device)
+
+ # Compute log probabilities using shared utility function
+ sequence_log_probs = compute_sequence_logprobs(
+ self.model, input_ids, attention_mask, requires_grad=False
+ )
+
+        # Remove the batch dimension for the single response
+        return sequence_log_probs.squeeze()
+
+
+class DatasetActor(ForgeActor):
+ """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self._setup_dataset(*args, **kwargs)
+
+ def _setup_dataset(self, *args, **kwargs):
+ def gsm8k_to_messages(sample):
+ question = sample["question"]
+ full_answer: str = sample["answer"]
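+            # GSM8K answers end with "#### <final answer>"; keep only that final value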
+ answer = full_answer.split("#### ")[1]
+ return {"question": question, "answer": answer}
+
+ ds = load_dataset(*args, **kwargs)
+ ds = ds.map(gsm8k_to_messages)
+ ds = ds.shuffle()
+ self._iterator = iter(ds)
+
+ @endpoint
+ async def __next__(self) -> dict[str, str] | None:
+ try:
+ return next(self._iterator)
+ except StopIteration:
+ return None
+
+
+async def main():
+ """Main GRPO training loop with rollout and training processes."""
+ group_size = 1
+ model = "Qwen/Qwen3-1.7B"
+
+ # ---- Setup services ---- #
+ default_service_cfg = ServiceConfig(
+ procs_per_replica=1,
+ num_replicas=1,
+ )
+
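+    # NOTE: the device placement below (policy on GPU 3, trainer on cuda:1,
+    # ref model on cuda:2) assumes a single host with at least 4 visible GPUs.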
+ policy = await spawn_service(
+ default_service_cfg,
+ Policy,
+ PolicyConfig(
+ num_workers=1,
+ worker_params=WorkerConfig(model=model),
+ sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
+ available_devices="3",
+ ),
+ )
+
+ trainer = await spawn_service(
+ default_service_cfg,
+ Trainer,
+ learning_rate=1e-5,
+ beta=0.1,
+ model_name=model,
+ device=torch.device("cuda:1"),
+ )
+
+ replay_buffer = await spawn_service(
+ default_service_cfg,
+ ReplayBuffer,
+ batch_size=4,
+ max_policy_age=1,
+ )
+
+ dataloader = await spawn_service(
+ default_service_cfg,
+ DatasetActor,
+ "openai/gsm8k",
+ "main",
+ split="train",
+ streaming=True,
+ )
+
+ compute_advantages = await spawn_service(
+ default_service_cfg,
+ ComputeAdvantages,
+ gamma=0.99,
+ lambda_=0.95,
+ )
+
+ ref_model = await spawn_service(
+ default_service_cfg,
+ RefModel,
+ model_name=model,
+ device=torch.device("cuda:2"),
+ )
+
+ reward_actor = await spawn_service(
+ default_service_cfg,
+ RewardActor,
+ scoring_functions=[math_scoring_function, thinking_scoring_function],
+ )
+
+ print("All services initialized successfully!")
+
+ # ---- Core RL loops ---- #
+ async def continuous_rollouts():
+ rollout_count = 0
+ # TODO: Move this into setup
+ asyncio.create_task(policy.run_processing.call())
+ while True:
+ sample = await dataloader.__next__.choose()
+ if sample is None:
+ print("Dataloader is empty, exiting continuous rollout")
+ return
+ prompt, target = sample["question"], sample["answer"]
+ version = 0 # await policy.get_current_version.choose()
+ episode = Episode(
+ episode_id=rollout_count,
+ prompt=prompt,
+ target=target,
+ policy_version=version,
+ )
+ actions = await policy.generate.choose(prompt)
+ for action in actions:
+ ref_logprobs = await ref_model.forward.choose(action.token_ids)
+ reward = await reward_actor.evaluate_response.choose(
+ prompt=prompt, response=action.text, target=target
+ )
+ episode.add_group(
+ Group(
+ response=action.text,
+ ref_logprobs=ref_logprobs,
+ reward=reward,
+ )
+ )
+
+ advantages = await compute_advantages.__call__.choose(episode.groups)
+ for advantage, group in zip(advantages, episode.groups):
+ group.advantage = advantage
+
+ await replay_buffer.add.choose(episode)
+
+ rollout_count += 1
+ if rollout_count % 10 == 0:
+ avg_reward = sum(group.reward for group in episode.groups) / len(
+ episode.groups
+ )
+ print(
+ f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
+ )
+
+ async def continuous_training():
+ training_step = 0
+ while True:
+ batch = await replay_buffer.sample.choose(curr_policy_version=0)
+ if batch is None:
+ await asyncio.sleep(0.1)
+ else:
+ training_result = await trainer.train_step.choose(batch)
+ training_step += 1
+ if training_step % 10 == 0:
+ print(f"Completed {training_step} training steps")
+ if training_result:
+ print(f"Latest loss: {training_result.get('loss', 'N/A')}")
+ # await trainer.update_weights(policy)
+
+ print("Starting GRPO training loops...")
+ rollout_task = asyncio.create_task(continuous_rollouts())
+ training_task = asyncio.create_task(continuous_training())
+
+ try:
+ await asyncio.gather(rollout_task, training_task)
+ except KeyboardInterrupt:
+ print("Training interrupted by user")
+ rollout_task.cancel()
+ training_task.cancel()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 32b4ba54a..5ac24c006 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -60,6 +60,7 @@ class SamplingOverrides:
num_samples: int
guided_decoding: bool = False
+ max_tokens: int = 512
@dataclass
@@ -87,6 +88,7 @@ class PolicyConfig:
num_workers: int
worker_params: WorkerConfig
sampling_params: SamplingOverrides
+    available_devices: str | None = None
@dataclass
@@ -102,6 +104,11 @@ class Policy(PolicyInterface):
@endpoint
async def setup(self):
# Set up policy_worker
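+        # Fall back to all local GPUs when no explicit device list is configured.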
+ self.available_devices = (
+ self.config.available_devices
+ if self.config.available_devices is not None
+ else ",".join(str(i) for i in range(torch.cuda.device_count()))
+ )
await self.spawn_workers()
self.request_id = 0
@@ -157,6 +164,7 @@ async def spawn_workers(self):
env={
"MASTER_ADDR": str(get_loopback_ip()),
"MASTER_PORT": str(get_open_port()),
+ "CUDA_VISIBLE_DEVICES": self.available_devices,
},
)
self.policy_worker = await self.worker_mesh.spawn(
@@ -200,7 +208,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
if (num_samples := self.sampling_params.n) == 1:
self.output_processor.add_request(request, prompt_str, None, 0)
request, _ = self.preprocess_add_request(request)
-
request_fut = asyncio.Future()
self.requests[request_id] = (None, request_fut)
@@ -456,7 +463,6 @@ def convert_input(prompt=None, prompt_token_ids=None) -> Dict:
def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
default_params = vllm_config.model_config.get_diff_sampling_param()
- default_params["max_tokens"] = 512
if overrides is not None:
default_params |= overrides
if default_params:
diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py
index 96146e862..d0e70e85f 100644
--- a/src/forge/actors/replay_buffer.py
+++ b/src/forge/actors/replay_buffer.py
@@ -11,7 +11,6 @@
from monarch.actor import endpoint
from forge.controller import ForgeActor
-from forge.types import Trajectory
@dataclass
@@ -24,37 +23,34 @@ class ReplayBuffer(ForgeActor):
@endpoint
async def setup(self) -> None:
- self.buffer: list[Trajectory] = []
+ self.buffer: list = []
if self.seed is None:
self.seed = random.randint(0, 2**32)
random.seed(self.seed)
self.sampler = random.sample
@endpoint
- async def add(self, trajectory: Trajectory) -> None:
- self.buffer.append(trajectory)
+ async def add(self, episode) -> None:
+ self.buffer.append(episode)
@endpoint
- async def sample(
- self, curr_policy_version: int, batch_size: int | None = None
- ) -> list[Trajectory] | None:
+ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
"""Sample from the replay buffer.
Args:
curr_policy_version (int): The current policy version.
- batch_size (int, optional): Number of trajectories to sample. If none, defaults to batch size
+ batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size
passed in at initialization.
Returns:
- A list of sampled trajectories or None if there are not enough trajectories in the buffer.
+ A list of sampled episodes or None if there are not enough episodes in the buffer.
"""
bsz = batch_size if batch_size is not None else self.batch_size
- # Evict old trajectories
+ # Evict old episodes
self._evict(curr_policy_version)
if bsz > len(self.buffer):
- print("Not enough trajectories in the buffer.")
return None
# TODO: Make this more efficient
@@ -62,12 +58,12 @@ async def sample(
sorted_idxs = sorted(
idx_to_sample, reverse=True
) # Sort in desc order to avoid shifting idxs
- sampled_trajectories = [self.buffer.pop(i) for i in sorted_idxs]
- return sampled_trajectories
+ sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
+ return sampled_episodes
@endpoint
async def evict(self, curr_policy_version: int) -> None:
- """Evict trajectories from the replay buffer if they are too old based on the current policy version
+ """Evict episodes from the replay buffer if they are too old based on the current policy version
and the max policy age allowed.
Args:
@@ -83,17 +79,17 @@ def _evict(self, curr_policy_version: int) -> None:
]
@endpoint
- async def _getitem(self, idx: int) -> Trajectory:
+ async def _getitem(self, idx: int):
return self.buffer[idx]
@endpoint
async def _numel(self) -> int:
- """Number of elements (trajectories) in the replay buffer."""
+ """Number of elements (episodes) in the replay buffer."""
return len(self.buffer)
@endpoint
async def clear(self) -> None:
- """Clear the replay buffer immediately - dropping all trajectories."""
+ """Clear the replay buffer immediately - dropping all episodes."""
self.buffer.clear()
@endpoint
diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py
index 9bb952679..8b1513d7f 100644
--- a/src/forge/controller/replica.py
+++ b/src/forge/controller/replica.py
@@ -13,11 +13,11 @@
from enum import Enum
from typing import Optional
-from monarch.actor import Actor, ActorError, ProcMesh
-
from forge.controller import get_proc_mesh
from forge.types import ProcessConfig
+from monarch.actor import Actor, ActorError, ProcMesh
+
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)