diff --git a/apps/grpo/main.py b/apps/grpo/main.py
new file mode 100644
index 000000000..80e6ee10a
--- /dev/null
+++ b/apps/grpo/main.py
@@ -0,0 +1,530 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+import time
+from dataclasses import dataclass
+from typing import Callable
+
+import torch
+from datasets import load_dataset
+from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.controller import ServiceConfig, spawn_service
+from forge.controller.actor import ForgeActor
+from monarch.actor import endpoint
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def compute_sequence_logprobs(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    requires_grad: bool = True,
+) -> torch.Tensor:
+    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
+
+    with context_manager:
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+        logits = outputs.logits
+
+        # Apply log softmax to get log probabilities
+        log_probs = torch.log_softmax(logits, dim=-1)
+
+        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
+        shifted_input_ids = input_ids[:, 1:]  # Remove first token
+        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
+
+        # Gather log probabilities for actual tokens
+        token_log_probs = torch.gather(
+            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
+        ).squeeze(-1)
+
+        # Sum log probabilities across sequence (masked by attention)
+        shifted_attention_mask = attention_mask[:, 1:]
+        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
+
+    return sequence_log_probs
+
+
+@dataclass
+class Group:
+    response: str  # The response text for tokenization
+    ref_logprobs: torch.Tensor
+    reward: float
+    advantage: float = 0.0
+
+
+class Episode:
+    """Episode container for GRPO rollouts."""
+
+    def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int):
+        self.episode_id = episode_id
+        self.prompt = prompt
+        self.target = target
+        self.policy_version = policy_version
+        self.groups: list[Group] = []
+
+    def add_group(self, group: Group):
+        self.groups.append(group)
+
+
+class Trainer(ForgeActor):
+    """GRPO Trainer implementation for policy optimization."""
+
+    def __init__(
+        self,
+        learning_rate: float = 1e-5,
+        beta: float = 0.1,
+        model_name: str = "",
+        device: torch.device | None = None,
+    ):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.beta = beta  # KL penalty coefficient
+        self.model_name = model_name
+
+        # Set device
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = device
+
+        # Initialize model and tokenizer
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        ).to(self.device)
+        self.model.train()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Initialize optimizer
+        self.optimizer = torch.optim.AdamW(
+            self.model.parameters(), lr=self.learning_rate
+        )
+
+        self.logger.info(f"Model initialized on {self.device}")
+
+    @endpoint
+    async def train_step(self, batch: list[Episode]):
+        total_loss = 0.0
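+        # Counts how many response groups contribute to this step; reported
+        # back to the caller alongside the average loss.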
+        num_groups_processed = 0
+
+        for episode in batch:
+            groups = episode.groups
+
+            # Collect all response texts and corresponding data
+            response_texts = []
+            ref_logprobs_list = []
+            advantages_list = []
+
+            for group in groups:
+                response_texts.append(group.response)
+                ref_logprobs_list.append(group.ref_logprobs)
+                advantages_list.append(group.advantage)
+
+            # Tokenize all responses in batch
+            tokenized = self.tokenizer(
+                response_texts,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,  # Adjust based on your needs
+            )
+
+            input_ids = tokenized["input_ids"].to(self.device)
+            attention_mask = tokenized["attention_mask"].to(self.device)
+
+            # Compute current policy log probabilities using the model
+            current_logprobs = compute_sequence_logprobs(
+                self.model, input_ids, attention_mask, requires_grad=True
+            )
+
+            # Convert ref_logprobs and advantages to tensors
+            ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device)
+            advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to(
+                self.device
+            )
+
+            # Compute GRPO loss components
+            # Ratio between current policy and reference policy
+            ratio = torch.exp(current_logprobs - ref_logprobs_tensor)
+
+            # Policy gradient loss weighted by advantages
+            pg_loss = -torch.mean(ratio * advantages_tensor)
+
+            # KL penalty to prevent policy from deviating too far from reference
+            kl_penalty = self.beta * torch.mean(
+                (current_logprobs - ref_logprobs_tensor) ** 2
+            )
+
+            # Total GRPO loss
+            loss = pg_loss + kl_penalty
+            total_loss += loss.item()
+            num_groups_processed += len(groups)
+
+            self.optimizer.zero_grad()
+            loss.backward()
+
+            # Gradient clipping (optional but recommended for stability)
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+            self.optimizer.step()
+
+        avg_loss = total_loss / len(batch) if batch else 0.0
+
+        return {"loss": avg_loss, "groups_processed": num_groups_processed}
+
+    @endpoint
+    async def update_weights(self, policy_actor):
+        """Update policy model weights with trainer's current weights."""
+        # Time how long it takes to update weights
+        start_time = time.time()
+
+        # Set model to eval mode for weight extraction
+        self.model.eval()
+
+        # Extract current model state dict
+        model_state_dict = self.model.state_dict()
+
+        # Convert tensors to CPU for transfer (if they're on GPU)
+        cpu_state_dict = {}
+        for key, tensor in model_state_dict.items():
+            cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
+
+        # Update the policy actor's model weights
+        await policy_actor.update_model_weights.choose(cpu_state_dict)
+
+        # Set model back to training mode
+        self.model.train()
+
+        # Log the time taken
+        end_time = time.time()
+        self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
+
+
+def math_scoring_function(prompt: str, response: str, target: str) -> float:
+    """Function to score math correctness."""
+    import re
+
+    # Extract expected answer from target
+    expected_answer = (
+        float(target.strip())
+        if target.strip().replace(".", "").replace("-", "").isdigit()
+        else None
+    )
+
+    # Extract model answer from response
+    patterns = [
+        r"####\s*([+-]?\d+(?:\.\d+)?)",  # GSM8K style answer format
+        r"(?:the\s+)?answer\s+is\s*([+-]?\d+(?:\.\d+)?)",
+        r"(?:answer:|result:)\s*([+-]?\d+(?:\.\d+)?)",
+        r"=\s*([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)",  # equals near end
+        r"\b([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)",  # number at end
+        r"([+-]?\d+(?:\.\d+)?)",  # any number (fallback)
+    ]
+
+    model_answer = None
+    response_lower = response.lower().strip()
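+    # Patterns are ordered from most to least specific: the first pattern that
+    # matches anything wins, and the last match within it is used.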
+    for pattern in patterns:
+        matches = re.findall(pattern, response_lower)
+        if matches:
+            model_answer = float(matches[-1])
+            break
+
+    if expected_answer is None or model_answer is None:
+        return 0.1  # Partial credit for attempting
+
+    # Check if answers match (with some tolerance for floating point)
+    if abs(expected_answer - model_answer) < 1e-6:
+        return 1.0  # Correct answer
+    else:
+        return 0.0  # Incorrect answer
+
+
+def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
+    """Function to score thinking tag usage."""
+    # Check if response contains <think> tags
+    if "<think>" in response.lower() and "</think>" in response.lower():
+        return 0.5
+    else:
+        return 0.0
+
+
+class RewardActor(ForgeActor):
+    """Reward actor that uses a list of scoring functions."""
+
+    def __init__(self, scoring_functions: list[Callable]):
+        super().__init__()
+        self.scoring_functions = scoring_functions
+
+    @endpoint
+    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+        total_reward = 0.0
+        for scoring_fn in self.scoring_functions:
+            reward = scoring_fn(prompt, response, target)
+            total_reward += reward
+        return total_reward
+
+
+class ComputeAdvantages(ForgeActor):
+    """Compute advantages for GRPO using reward signals."""
+
+    def __init__(self, gamma: float = 0.99, lambda_: float = 0.95):
+        super().__init__()
+        self.gamma = gamma  # Discount factor
+        self.lambda_ = lambda_  # GAE lambda parameter
+
+    @endpoint
+    async def __call__(self, groups: list[Group]) -> list[float]:
+        # Extract rewards from groups
+        rewards = [group.reward for group in groups]
+        num_groups = len(groups)
+
+        # For simplicity, use reward-to-go as advantages
+        # This is a valid advantage estimator: A(s,a) = Q(s,a) - V(s)
+        # where Q(s,a) ≈ reward-to-go and V(s) ≈ average reward
+
+        # Compute discounted reward-to-go for each step
+        reward_to_go = []
+        running_reward = 0.0
+
+        # Calculate discounted returns (reward-to-go)
+        for t in reversed(range(num_groups)):
+            running_reward = rewards[t] + self.gamma * running_reward
+            reward_to_go.insert(0, running_reward)
+
+        # Compute baseline (mean of rewards) and advantages
+        baseline = sum(rewards) / len(rewards) if rewards else 0.0
+        advantages = [rtg - baseline for rtg in reward_to_go]
+
+        # Normalize advantages to have zero mean and unit variance
+        if len(advantages) > 1:
+            mean_adv = sum(advantages) / len(advantages)
+            var_adv = sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)
+            std_adv = (var_adv**0.5) if var_adv > 1e-8 else 1.0
+            advantages = [(a - mean_adv) / std_adv for a in advantages]
+
+        return advantages
+
+
+class RefModel(ForgeActor):
+    def __init__(self, model_name, device: torch.device | None = None):
+        super().__init__()
+        self.model_name = model_name
+
+        # Set device
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = device
+
+        # Initialize model and tokenizer
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        ).to(self.device)
+
+        # Set model to eval mode for reference computations
+        self.model.eval()
+
+        self.logger.info(f"Model initialized on {self.device}")
+
+    @endpoint
+    async def forward(self, token_ids: list[int]) -> torch.Tensor:
+        # Use provided token_ids directly
+        input_ids = (
+            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
+        )
+        # Create attention mask of all 1s since we have actual tokens (no padding)
+        attention_mask = torch.ones_like(input_ids).to(self.device)
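+        # The reference model is frozen: requires_grad=False below makes
+        # compute_sequence_logprobs run under torch.no_grad(), so no autograd
+        # graph is built for this forward pass.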
+
+        # Compute log probabilities using shared utility function
+        sequence_log_probs = compute_sequence_logprobs(
+            self.model, input_ids, attention_mask, requires_grad=False
+        )
+
+        return (
+            sequence_log_probs.squeeze()
+        )  # Remove batch dimension for single response
+
+
+class DatasetActor(ForgeActor):
+    """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self._setup_dataset(*args, **kwargs)
+
+    def _setup_dataset(self, *args, **kwargs):
+        def gsm8k_to_messages(sample):
+            question = sample["question"]
+            full_answer: str = sample["answer"]
+            answer = full_answer.split("#### ")[1]
+            return {"question": question, "answer": answer}
+
+        ds = load_dataset(*args, **kwargs)
+        ds = ds.map(gsm8k_to_messages)
+        ds = ds.shuffle()
+        self._iterator = iter(ds)
+
+    @endpoint
+    async def __next__(self) -> dict[str, str] | None:
+        try:
+            return next(self._iterator)
+        except StopIteration:
+            return None
+
+
+async def main():
+    """Main GRPO training loop with rollout and training processes."""
+    group_size = 1
+    model = "Qwen/Qwen3-1.7B"
+
+    # ---- Setup services ---- #
+    default_service_cfg = ServiceConfig(
+        procs_per_replica=1,
+        num_replicas=1,
+    )
+
+    policy = await spawn_service(
+        default_service_cfg,
+        Policy,
+        PolicyConfig(
+            num_workers=1,
+            worker_params=WorkerConfig(model=model),
+            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
+            available_devices="3",
+        ),
+    )
+
+    trainer = await spawn_service(
+        default_service_cfg,
+        Trainer,
+        learning_rate=1e-5,
+        beta=0.1,
+        model_name=model,
+        device=torch.device("cuda:1"),
+    )
+
+    replay_buffer = await spawn_service(
+        default_service_cfg,
+        ReplayBuffer,
+        batch_size=4,
+        max_policy_age=1,
+    )
+
+    dataloader = await spawn_service(
+        default_service_cfg,
+        DatasetActor,
+        "openai/gsm8k",
+        "main",
+        split="train",
+        streaming=True,
+    )
+
+    compute_advantages = await spawn_service(
+        default_service_cfg,
+        ComputeAdvantages,
+        gamma=0.99,
+        lambda_=0.95,
+    )
+
+    ref_model = await spawn_service(
+        default_service_cfg,
+        RefModel,
+        model_name=model,
+        device=torch.device("cuda:2"),
+    )
+
+    reward_actor = await spawn_service(
+        default_service_cfg,
+        RewardActor,
+        scoring_functions=[math_scoring_function, thinking_scoring_function],
+    )
+
+    print("All services initialized successfully!")
+
+    # ---- Core RL loops ---- #
+    async def continuous_rollouts():
+        rollout_count = 0
+        # TODO: Move this into setup
+        asyncio.create_task(policy.run_processing.call())
+        while True:
+            sample = await dataloader.__next__.choose()
+            if sample is None:
+                print("Dataloader is empty, exiting continuous rollout")
+                return
+            prompt, target = sample["question"], sample["answer"]
+            version = 0  # await policy.get_current_version.choose()
+            episode = Episode(
+                episode_id=rollout_count,
+                prompt=prompt,
+                target=target,
+                policy_version=version,
+            )
+            actions = await policy.generate.choose(prompt)
+            for action in actions:
+                ref_logprobs = await ref_model.forward.choose(action.token_ids)
+                reward = await reward_actor.evaluate_response.choose(
+                    prompt=prompt, response=action.text, target=target
+                )
+                episode.add_group(
+                    Group(
+                        response=action.text,
+                        ref_logprobs=ref_logprobs,
+                        reward=reward,
+                    )
+                )
+
+            advantages = await compute_advantages.__call__.choose(episode.groups)
+            for advantage, group in zip(advantages, episode.groups):
+                group.advantage = advantage
+
+            await replay_buffer.add.choose(episode)
+
+            rollout_count += 1
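+            # Note: avg_reward below is averaged over the most recent episode's
+            # groups only, not over every rollout generated so far.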
+            if rollout_count % 10 == 0:
+                avg_reward = sum(group.reward for group in episode.groups) / len(
+                    episode.groups
+                )
+                print(
+                    f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
+                )
+
+    async def continuous_training():
+        training_step = 0
+        while True:
+            batch = await replay_buffer.sample.choose(curr_policy_version=0)
+            if batch is None:
+                await asyncio.sleep(0.1)
+            else:
+                training_result = await trainer.train_step.choose(batch)
+                training_step += 1
+                if training_step % 10 == 0:
+                    print(f"Completed {training_step} training steps")
+                    if training_result:
+                        print(f"Latest loss: {training_result.get('loss', 'N/A')}")
+                # await trainer.update_weights(policy)
+
+    print("Starting GRPO training loops...")
+    rollout_task = asyncio.create_task(continuous_rollouts())
+    training_task = asyncio.create_task(continuous_training())
+
+    try:
+        await asyncio.gather(rollout_task, training_task)
+    except KeyboardInterrupt:
+        print("Training interrupted by user")
+        rollout_task.cancel()
+        training_task.cancel()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 32b4ba54a..5ac24c006 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -60,6 +60,7 @@ class SamplingOverrides:
 
     num_samples: int
     guided_decoding: bool = False
+    max_tokens: int = 512
 
 
 @dataclass
@@ -87,6 +88,7 @@ class PolicyConfig:
     num_workers: int
     worker_params: WorkerConfig
     sampling_params: SamplingOverrides
+    available_devices: str | None = None
 
 
 @dataclass
@@ -102,6 +104,11 @@ class Policy(PolicyInterface):
     @endpoint
     async def setup(self):
         # Set up policy_worker
+        self.available_devices = (
+            self.config.available_devices
+            if self.config.available_devices is not None
+            else ",".join(str(i) for i in range(torch.cuda.device_count()))
+        )
         await self.spawn_workers()
 
         self.request_id = 0
@@ -157,6 +164,7 @@ async def spawn_workers(self):
            env={
                "MASTER_ADDR": str(get_loopback_ip()),
                "MASTER_PORT": str(get_open_port()),
+               "CUDA_VISIBLE_DEVICES": self.available_devices,
            },
        )
        self.policy_worker = await self.worker_mesh.spawn(
@@ -200,7 +208,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu
         if (num_samples := self.sampling_params.n) == 1:
             self.output_processor.add_request(request, prompt_str, None, 0)
             request, _ = self.preprocess_add_request(request)
-
         request_fut = asyncio.Future()
         self.requests[request_id] = (None, request_fut)
 
@@ -456,7 +463,6 @@ def convert_input(prompt=None, prompt_token_ids=None) -> Dict:
 
 def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
     default_params = vllm_config.model_config.get_diff_sampling_param()
-    default_params["max_tokens"] = 512
     if overrides is not None:
         default_params |= overrides
     if default_params:
diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py
index 96146e862..d0e70e85f 100644
--- a/src/forge/actors/replay_buffer.py
+++ b/src/forge/actors/replay_buffer.py
@@ -11,7 +11,6 @@
 from monarch.actor import endpoint
 
 from forge.controller import ForgeActor
-from forge.types import Trajectory
 
 
 @dataclass
@@ -24,37 +23,34 @@ class ReplayBuffer(ForgeActor):
 
     @endpoint
     async def setup(self) -> None:
-        self.buffer: list[Trajectory] = []
+        self.buffer: list = []
         if self.seed is None:
             self.seed = random.randint(0, 2**32)
         random.seed(self.seed)
         self.sampler = random.sample
 
     @endpoint
-    async def add(self, trajectory: Trajectory) -> None:
-        self.buffer.append(trajectory)
+    async def add(self, episode) -> None:
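+        # Episodes are stored untyped now that the Trajectory dependency is gone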
+        self.buffer.append(episode)
 
     @endpoint
-    async def sample(
-        self, curr_policy_version: int, batch_size: int | None = None
-    ) -> list[Trajectory] | None:
+    async def sample(self, curr_policy_version: int, batch_size: int | None = None):
         """Sample from the replay buffer.
 
         Args:
             curr_policy_version (int): The current policy version.
-            batch_size (int, optional): Number of trajectories to sample. If none, defaults to batch size
+            batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size
                 passed in at initialization.
 
         Returns:
-            A list of sampled trajectories or None if there are not enough trajectories in the buffer.
+            A list of sampled episodes or None if there are not enough episodes in the buffer.
         """
         bsz = batch_size if batch_size is not None else self.batch_size
 
-        # Evict old trajectories
+        # Evict old episodes
        self._evict(curr_policy_version)
 
         if bsz > len(self.buffer):
-            print("Not enough trajectories in the buffer.")
             return None
 
         # TODO: Make this more efficient
@@ -62,12 +58,12 @@ async def sample(
         sorted_idxs = sorted(
             idx_to_sample, reverse=True
         )  # Sort in desc order to avoid shifting idxs
-        sampled_trajectories = [self.buffer.pop(i) for i in sorted_idxs]
-        return sampled_trajectories
+        sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
+        return sampled_episodes
 
     @endpoint
     async def evict(self, curr_policy_version: int) -> None:
-        """Evict trajectories from the replay buffer if they are too old based on the current policy version
+        """Evict episodes from the replay buffer if they are too old based on the current policy version
         and the max policy age allowed.
 
         Args:
@@ -83,17 +79,17 @@ def _evict(self, curr_policy_version: int) -> None:
         ]
 
     @endpoint
-    async def _getitem(self, idx: int) -> Trajectory:
+    async def _getitem(self, idx: int):
         return self.buffer[idx]
 
     @endpoint
     async def _numel(self) -> int:
-        """Number of elements (trajectories) in the replay buffer."""
+        """Number of elements (episodes) in the replay buffer."""
         return len(self.buffer)
 
     @endpoint
     async def clear(self) -> None:
-        """Clear the replay buffer immediately - dropping all trajectories."""
+        """Clear the replay buffer immediately - dropping all episodes."""
         self.buffer.clear()
 
     @endpoint
diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py
index 9bb952679..8b1513d7f 100644
--- a/src/forge/controller/replica.py
+++ b/src/forge/controller/replica.py
@@ -13,11 +13,11 @@
 from enum import Enum
 from typing import Optional
 
-from monarch.actor import Actor, ActorError, ProcMesh
-
 from forge.controller import get_proc_mesh
 from forge.types import ProcessConfig
 
+from monarch.actor import Actor, ActorError, ProcMesh
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)