diff --git a/.github/workflows/unit_test.yaml b/.github/workflows/unit_test.yaml
index ed4cbc5c3..d9e5dbe06 100644
--- a/.github/workflows/unit_test.yaml
+++ b/.github/workflows/unit_test.yaml
@@ -27,6 +27,11 @@ jobs:
         run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
       - name: Install monarch
         run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
+      - name: Install torchstore
+        run: |
+          eval "$(ssh-agent -s)"
+          ssh-add - <<< '${{ secrets.FORGE_GITHUB_CI_FOR_TORCHSTORE }}'
+          python -m pip install git+ssh://git@github.com/meta-pytorch/torchstore.git
       - name: Install dependencies
         run: python -m pip install --no-build-isolation -e ".[dev]"
       - name: Run unit tests with coverage
diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index dee7972d4..f08aab228 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -6,193 +6,220 @@

 import asyncio
 import logging
-import time
+import uuid
 from dataclasses import dataclass
-from typing import Callable
+from typing import Any, Callable, Optional

 import torch
+import torch.nn.functional as F
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
-from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
-from torchtitan.config.job_config import Model as TitanJobModelConfig
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from torch import nn
+from transformers import AutoModelForCausalLM
+from vllm.transformers_utils.tokenizer import get_tokenizer

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)


-@dataclass
-class Group:
-    response: str  # The response text for tokenization
-    ref_logprobs: torch.Tensor
-    reward: float
-    advantage: float = 0.0
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    context_length = logits.shape[1] - input_ids.shape[1]
+
+    # Truncate request logits and drop last
+    logits = logits[:, context_length - 1 : -1]
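+    # Example of the slice (illustrative): with a 5-token prompt and a
+    # 3-token response, logits is [b, 8, vocab] while input_ids holds only
+    # the 3 response tokens, so context_length = 5 and the slice keeps
+    # positions 4..6 - exactly the logits that predict each response token,
+    # since the logits at position i score the token at position i + 1.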
+ """ + + def __init__(self, epsilon=0.1, beta=0.1): + super().__init__() + self.epsilon = epsilon + self.beta = beta + + def forward(self, logprobs, ref_logprobs, advantages, padding_mask): + per_token_kl = ( + torch.exp(ref_logprobs.detach() - logprobs) + - (ref_logprobs.detach() - logprobs) + - 1 + ) + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) + loss = ( + (per_token_loss * padding_mask).sum(dim=1) + / (padding_mask.sum(dim=1) + 1e-8) + ).mean() + return loss +@dataclass class Episode: - """Episode container for GRPO rollouts.""" + # TODO: add adtional layer for multi-turn + episode_id: str + request: str + policy_version: int + pad_id: int + request_len: int + response_len: int + target: Optional[Any] = None + # processed data + response: Optional[str] = None + request_tokens: Optional[list[int]] = None + response_tokens: Optional[list[int]] = None + ref_logprobs: Optional[torch.Tensor] = None + reward: Optional[float] = None + advantage: Optional[float] = None + + @property + def request_tensor(self): + tensor = torch.tensor(self.request_tokens, dtype=torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self): + tensor = torch.tensor(self.response_tokens, dtype=torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor - def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int): - self.episode_id = episode_id - self.prompt = prompt - self.target = target - self.policy_version = policy_version - self.groups: list[Group] = [] - def add_group(self, group: Group): - self.groups.append(group) +@dataclass +class Group: + group_id: str + episodes: list[Episode] + + @classmethod + def new_group( + cls, + group_id: int, + group_size: int, + request: str, + policy_version: int, + pad_id: int, + request_len: int, + response_len: int, + target: Any = None, + ): + episodes = [] + for i in range(group_size): + episodes.append( + Episode( + episode_id=str(uuid.uuid4()), + request=request, + policy_version=policy_version, + pad_id=pad_id, + request_len=request_len, + response_len=response_len, + target=target, + ) + ) + return cls(str(group_id), episodes) +@dataclass class Trainer(ForgeActor): """GRPO Trainer implementation for policy optimization.""" - def __init__( - self, - learning_rate: float = 1e-5, - beta: float = 0.1, - model_name: str = "", - device: torch.device | None = None, - ): - super().__init__() - self.learning_rate = learning_rate - self.beta = beta # KL penalty coefficient - self.model_name = model_name + model_name: str + learning_rate: float = 1e-5 + beta: float = 0.1 + epsilon: float = 0.1 + device: torch.device | None = None + @endpoint + def setup(self): # Set device - if device is None: + if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self.device = device - # Initialize model and tokenizer + # Initialize model self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.bfloat16, + self.model_name, + dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) self.model.train() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - if self.tokenizer.pad_token is None: - 
+        per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl)
+        loss = (
+            (per_token_loss * padding_mask).sum(dim=1)
+            / (padding_mask.sum(dim=1) + 1e-8)
+        ).mean()
+        return loss


+@dataclass
 class Episode:
-    """Episode container for GRPO rollouts."""
+    # TODO: add additional layer for multi-turn
+    episode_id: str
+    request: str
+    policy_version: int
+    pad_id: int
+    request_len: int
+    response_len: int
+    target: Optional[Any] = None
+    # processed data
+    response: Optional[str] = None
+    request_tokens: Optional[list[int]] = None
+    response_tokens: Optional[list[int]] = None
+    ref_logprobs: Optional[torch.Tensor] = None
+    reward: Optional[float] = None
+    advantage: Optional[float] = None
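+
+    # Requests are left-padded and responses right-padded (below), so once
+    # the two tensors are concatenated the real tokens always meet at a fixed
+    # offset of request_len; logit/response alignment downstream is then
+    # independent of per-sample lengths.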
weights.""" - # Time how long it takes to update weights - start_time = time.time() - - # Set model to eval mode for weight extraction - self.model.eval() - - # Extract current model state dict - model_state_dict = self.model.state_dict() - - # Convert tensors to CPU for transfer (if they're on GPU) - cpu_state_dict = {} - for key, tensor in model_state_dict.items(): - cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor - - # Update the policy actor's model weights - await policy_actor.update_model_weights.choose(cpu_state_dict) - - # Set model back to training mode - self.model.train() - - # Log the time taken - end_time = time.time() - self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds") + async def push_weights(self): + pass +@dataclass class RewardActor(ForgeActor): """Reward actor that uses a list of scoring functions.""" - def __init__(self, reward_functions: list[Callable]): - super().__init__() - self.reward_functions = reward_functions + reward_functions: list[Callable] @endpoint async def evaluate_response(self, prompt: str, response: str, target: str) -> float: @@ -206,76 +233,105 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl class ComputeAdvantages(ForgeActor): """Compute advantages for GRPO using reward signals.""" - def __init__(self, gamma: float = 0.99, lambda_: float = 0.95): - super().__init__() - self.gamma = gamma # Discount factor - self.lambda_ = lambda_ # GAE lambda parameter - @endpoint - async def __call__(self, groups: list[Group]) -> list[float]: - # Extract rewards from groups - rewards = [group.reward for group in groups] - num_groups = len(groups) + async def compute(self, group: Group) -> list[float]: + # TODO: add batch processing + rewards = torch.Tensor([[e.reward for e in group.episodes]]) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + + # if std is nan, return 0s. 
+        # compute loss
+        mask = response != pad_id
+        loss = self.loss(logprobs, ref_logprobs, advantages, mask)

-            # Gradient clipping (optional but recommended for stability)
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+        self.optimizer.zero_grad()
+        loss.backward()

-            self.optimizer.step()
+        # # Gradient clipping (optional but recommended for stability)
+        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

-        avg_loss = total_loss / len(batch) if batch else 0.0
+        self.optimizer.step()

-        return {"loss": avg_loss, "groups_processed": num_groups_processed}
+        return {"loss": loss.item()}

     @endpoint
-    async def update_weights(self, policy_actor):
-        """Update policy model weights with trainer's current weights."""
-        # Time how long it takes to update weights
-        start_time = time.time()
-
-        # Set model to eval mode for weight extraction
-        self.model.eval()
-
-        # Extract current model state dict
-        model_state_dict = self.model.state_dict()
-
-        # Convert tensors to CPU for transfer (if they're on GPU)
-        cpu_state_dict = {}
-        for key, tensor in model_state_dict.items():
-            cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
-
-        # Update the policy actor's model weights
-        await policy_actor.update_model_weights.choose(cpu_state_dict)
-
-        # Set model back to training mode
-        self.model.train()
-
-        # Log the time taken
-        end_time = time.time()
-        self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
+    async def push_weights(self):
+        pass


+@dataclass
 class RewardActor(ForgeActor):
     """Reward actor that uses a list of scoring functions."""

-    def __init__(self, reward_functions: list[Callable]):
-        super().__init__()
-        self.reward_functions = reward_functions
+    reward_functions: list[Callable]

     @endpoint
     async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
@@ -206,76 +233,105 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
 class ComputeAdvantages(ForgeActor):
     """Compute advantages for GRPO using reward signals."""

-    def __init__(self, gamma: float = 0.99, lambda_: float = 0.95):
-        super().__init__()
-        self.gamma = gamma  # Discount factor
-        self.lambda_ = lambda_  # GAE lambda parameter
-
     @endpoint
-    async def __call__(self, groups: list[Group]) -> list[float]:
-        # Extract rewards from groups
-        rewards = [group.reward for group in groups]
-        num_groups = len(groups)
+    async def compute(self, group: Group) -> list[float]:
+        # TODO: add batch processing
+        rewards = torch.Tensor([[e.reward for e in group.episodes]])
+        mean = rewards.mean(1, keepdim=True)
+        std = rewards.std(1, keepdim=True)
+
+        # If std is NaN, return zeros. TODO: remove this before shipping
+        if std.isnan().any():
+            advantages = torch.zeros_like(rewards)
+        else:
+            advantages = (rewards - mean) / (std + 1e-4)
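+
+        # GRPO's advantage is group-relative: each reward is z-scored against
+        # the other rollouts of the same prompt, so no learned value baseline
+        # is needed; the 1e-4 keeps the division stable when all rewards in a
+        # group are (nearly) identical.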
+
+        x = advantages.squeeze(0).tolist()
+        return x
+
+
+class RefModel(ForgeActor):
+    def __init__(self, model_name, device: torch.device | None = None):
+        super().__init__()
+        self.model_name = model_name

-        # For simplicity, use reward-to-go as advantages
-        # This is a valid advantage estimator: A(s,a) = Q(s,a) - V(s)
-        # where Q(s,a) ≈ reward-to-go and V(s) ≈ average reward
+        if device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = device

-        # Compute discounted reward-to-go for each step
-        reward_to_go = []
-        running_reward = 0.0
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            dtype=torch.bfloat16,
+            trust_remote_code=True,
+        ).to(self.device)
+        self.model.eval()

-        # Calculate discounted returns (reward-to-go)
-        for t in reversed(range(num_groups)):
-            running_reward = rewards[t] + self.gamma * running_reward
-            reward_to_go.insert(0, running_reward)
+        self.logger.info(f"Model initialized on {self.device}")

-        # Compute baseline (mean of rewards) and advantages
-        baseline = sum(rewards) / len(rewards) if rewards else 0.0
-        advantages = [rtg - baseline for rtg in reward_to_go]
+    @endpoint
+    async def forward(self, episode: Episode) -> torch.Tensor:
+        req, res = episode.request_tensor, episode.response_tensor
+        input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0)
+        mask = input_ids != episode.pad_id

-        # Normalize advantages to have zero mean and unit variance
-        if len(advantages) > 1:
-            mean_adv = sum(advantages) / len(advantages)
-            var_adv = sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)
-            std_adv = (var_adv**0.5) if var_adv > 1e-8 else 1.0
-            advantages = [(a - mean_adv) / std_adv for a in advantages]
+        with torch.inference_mode():
+            logits = self.model(input_ids=input_ids, attention_mask=mask).logits

-        return advantages
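+        # Keep only the response tokens so compute_logprobs can align the
+        # full-sequence logits against them; req is padded to request_len,
+        # so len(req) marks the fixed request/response boundary.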
+        input_ids = input_ids[:, len(req) :]
+        return compute_logprobs(logits, input_ids)


+@dataclass
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""

-    def __init__(
-        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
-    ):
-        super().__init__()
+    path: str
+    revision: str
+    data_split: str
+    streaming: bool
+    model: str

-        def gsm8k_to_messages(sample):
-            question = sample["question"]
-            full_answer: str = sample["answer"]
-            answer = full_answer.split("#### ")[1]
-            return {"question": question, "answer": answer}
+    @endpoint
+    def setup(self):
+        self.tokenizer = get_tokenizer(self.model)
+
+        def gsm8k_transform(sample):
+            request: str = sample["question"]
+            formatted_request = self.tokenizer.apply_chat_template(
+                [{"role": "user", "content": request}],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            target: str = sample["answer"]
+            formatted_target = target.split("#### ")[1]
+            return {"request": formatted_request, "target": formatted_target}

-        ds = load_dataset(path, config_name, split=split, streaming=streaming)
-        ds = ds.map(gsm8k_to_messages)
+        ds = load_dataset(
+            self.path, self.revision, split=self.data_split, streaming=self.streaming
+        )
+        ds = ds.map(gsm8k_transform)
         ds = ds.shuffle()
         self._iterator = iter(ds)

     @endpoint
-    async def __next__(self) -> dict[str, str] | None:
+    async def sample(self) -> dict[str, str] | None:
         try:
             return next(self._iterator)
         except StopIteration:
             return None

+    @endpoint
+    async def pad_token(self):
+        return self.tokenizer.pad_token_id
+

 async def main():
     """Main GRPO training loop with rollout and training processes."""
-    group_size = 1
-    model = "Qwen/Qwen3-0.6B"
-    titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")
+    group_size = 4
+    model = "Qwen/Qwen3-1.7B-Base"
+    max_req_tokens = 512
+    max_res_tokens = 128

     # ---- Setup WandB Logger ---- #
     logger = get_metric_logger(
@@ -298,9 +354,10 @@ async def main():
             ServiceConfig(procs_per_replica=1, num_replicas=1),
             DatasetActor,
             path="openai/gsm8k",
-            config_name="main",
-            split="train",
+            revision="main",
+            data_split="train",
             streaming=True,
+            model=model,
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
@@ -308,7 +365,7 @@ async def main():
             config=PolicyConfig(
                 worker_params=WorkerConfig(model=model),
                 sampling_params=SamplingOverrides(
-                    num_samples=group_size, max_tokens=16
+                    n=group_size, max_tokens=max_res_tokens
                 ),
             ),
         ),
@@ -316,7 +373,6 @@
             ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
             Trainer,
             learning_rate=1e-5,
-            beta=0.1,
             model_name=model,
         ),
         spawn_service(
@@ -328,13 +384,11 @@
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1),
             ComputeAdvantages,
-            gamma=0.99,
-            lambda_=0.95,
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-            TitanRefModel,
-            model=titan_model,
+            RefModel,
+            model_name=model,
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -348,49 +402,43 @@ async def main():
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
+        pad_id = await dataloader.pad_token.choose()
         while True:
-            sample = await dataloader.__next__.choose()
+            sample = await dataloader.sample.choose()
             if sample is None:
                 print("Dataloader is empty, exiting continuous rollout")
                 return
-            prompt, target = sample["question"], sample["answer"]
+            prompt, target = sample["request"], sample["target"]
             version = 0  # await policy.get_current_version.choose()
-            episode = Episode(
-                episode_id=rollout_count,
-                prompt=prompt,
-                target=target,
+            group = Group.new_group(
+                group_id=rollout_count,
+                group_size=group_size,
+                request=prompt,
                 policy_version=version,
+                pad_id=pad_id,
+                request_len=max_req_tokens,
+                response_len=max_res_tokens,
+                target=target,
             )
-            responses = await policy.generate.choose(prompt)
-            actions = responses.outputs
-            for action in actions:
-                request_tokens = responses.prompt_token_ids
-                response_tokens = action.token_ids
-                ref_logprobs = await ref_model.forward.choose(
-                    request=request_tokens, response=response_tokens
-                )
-                reward = await reward_actor.evaluate_response.choose(
-                    prompt=prompt, response=action.text, target=target
-                )
-                episode.add_group(
-                    Group(
-                        response=action.text,
-                        ref_logprobs=ref_logprobs,
-                        reward=reward,
-                    )
-                )
-            advantages = await compute_advantages.__call__.choose(episode.groups)
-            for advantage, group in zip(advantages, episode.groups):
-                group.advantage = advantage
+
+            responses = await policy.generate.choose(prompt)

-            await replay_buffer.add.choose(episode)
+            for episode, response in zip(group.episodes, responses.outputs):
+                episode.request_tokens = responses.prompt_token_ids
+                episode.response_tokens = response.token_ids
+                assert len(response.token_ids) <= max_res_tokens
+                episode.ref_logprobs = await ref_model.forward.choose(episode)
+                episode.reward = await reward_actor.evaluate_response.choose(
+                    prompt=prompt, response=response.text, target=target
+                )
+            advantages = await compute_advantages.compute.choose(group)
+            for episode, advantage in zip(group.episodes, advantages):
+                episode.advantage = advantage
+                await replay_buffer.add.choose(episode)

             rollout_count += 1
             if rollout_count % 10 == 0:
-                avg_reward = sum(group.reward for group in episode.groups) / len(
-                    episode.groups
-                )
+                avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes)
                 print(
                     f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
                 )
@@ -414,6 +462,7 @@ async def continuous_training():
             # await trainer.update_weights(policy)

     print("Starting GRPO training loops...")
+    # TODO: Start multiple rollouts once all services support it
     rollout_task = asyncio.create_task(continuous_rollouts())
     training_task = asyncio.create_task(continuous_training())
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 6f98a512e..0e438a28a 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -17,6 +17,7 @@
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from vllm.outputs import RequestOutput
+from vllm.transformers_utils.tokenizer import get_tokenizer


 async def main():
@@ -32,6 +33,13 @@ async def main():
     else:
         prompt = args.prompt

+    # format prompt
+    tokenizer = get_tokenizer(policy_config.worker_params.model)
+    messages = [{"role": "user", "content": prompt}]
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
     # Run the policy
     await run_vllm(service_config, policy_config, prompt)

@@ -41,7 +49,7 @@ def parse_args() -> Namespace:
     parser.add_argument(
         "--model",
         type=str,
-        default="meta-llama/Llama-3.1-8B-Instruct",
+        default="Qwen/Qwen3-1.7B",  # "meta-llama/Llama-3.1-8B-Instruct",
         help="Model to use",
     )
     parser.add_argument(
@@ -68,8 +76,9 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
     )

     sampling_params = SamplingOverrides(
-        num_samples=args.num_samples,
+        n=args.num_samples,
         guided_decoding=args.guided_decoding,
+        max_tokens=16,
     )

     policy_config = PolicyConfig(
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 798d2d2d0..77a3f0942 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -13,12 +13,6 @@
 from typing import Dict, List

 import torch
-
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore import MultiProcessStore
 from torchstore._state_dict_utils import DELIM
@@ -43,6 +37,12 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase

+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+
+from forge.data.sharding import VLLMSharding
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
+
 logger = logging.getLogger(__name__)
@@ -57,14 +57,21 @@ class SamplingOverrides:
     subset

     Args:
-        num_samples: Number of samples to generate.
+        n: Number of samples to generate.
         guided_decoding: Whether to use guided decoding.
+        max_tokens: Maximum number of tokens to generate.
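     """

-    num_samples: int
+    n: int
     guided_decoding: bool = False
     max_tokens: int = 512

+    def __post_init__(self):
+        # Rewriting the bool into vLLM's GuidedDecodingParams here lets
+        # Policy.setup() pass asdict(self.config.sampling_params) straight
+        # through as the sampling-param overrides; the ["Positive", "Negative"]
+        # choice set is the same hard-coded one the removed setup() logic used.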
""" - num_samples: int + n: int guided_decoding: bool = False max_tokens: int = 512 + def __post_init__(self): + gd_params = None + if self.guided_decoding: + gd_params = GuidedDecodingParams(choice=["Positive", "Negative"]) + self.guided_decoding = gd_params + @dataclass class WorkerConfig: @@ -174,17 +181,8 @@ async def setup(self): self.vllm_args = await self.policy_worker.get_vllm_args.choose() # Setup sampling params - sampling_overrides = self.config.sampling_params - overrides = { - "n": sampling_overrides.num_samples, - "guided_decoding": ( - GuidedDecodingParams(choice=["Positive", "Negative"]) - if sampling_overrides.guided_decoding - else None - ), - } self.sampling_params = get_default_sampling_params( - self.vllm_args, overrides=overrides + self.vllm_args, overrides=asdict(self.config.sampling_params) ) # Setup processors @@ -228,10 +226,11 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_id = str(self.request_id) # implement from a counter # Wraps prompt into a dict - prompt: Dict[str, str] = convert_input(prompt) + prompt: Dict[str, str] = convert_input(prompt=prompt) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} + # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507 truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens _validate_truncation_size( self.vllm_args.model_config.max_model_len,