From e6b7692a7fc24d184aefb0822ee3f881db0153de Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Fri, 29 Aug 2025 14:22:44 -0700 Subject: [PATCH 01/31] first changes --- apps/grpo/main.py | 145 +++++++++++++++++++++++++++---------- apps/vllm/main.py | 8 +- src/forge/actors/policy.py | 32 ++++---- 3 files changed, 130 insertions(+), 55 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 7fd10736f..e5608a83c 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. import asyncio +import copy import time +import uuid from dataclasses import dataclass from typing import Callable @@ -19,6 +21,7 @@ from forge.util.metric_logging import get_metric_logger from monarch.actor import endpoint from transformers import AutoModelForCausalLM, AutoTokenizer +from vllm.transformers_utils.tokenizer import get_tokenizer def compute_sequence_logprobs( @@ -52,26 +55,76 @@ def compute_sequence_logprobs( return sequence_log_probs +Role = Literal[ + "system", # Origin is system prompt + "user", # Origin is user + "assistant", # Origin is the model output + "agent", # Origin is generated + "tool", # Origin is return from a tool call +] + + @dataclass -class Group: - response: str # The response text for tokenization - ref_logprobs: torch.Tensor - reward: float - advantage: float = 0.0 +class Message: + role: Role + content: str +@dataclass class Episode: - """Episode container for GRPO rollouts.""" + # TODO: add adtional layer for multi-turn + episode_id: str + request: list[Message] + policy_version: int + # processed data + response: list[Message] + request_tokens: Optional[torch.Tensor] + response_tokens: Optional[torch.Tensor] + ref_logprobs: Optional[torch.Tensor] = None + reward: Optional[float] = None + advantage: Optional[float] = None + policy_version: Optional[int] = None + + +@dataclass +class Group: + group_id: str + episodes: list[Episode] + + @classmethod + def new_group( + cls, group_id: int, group_size: int, request: list[Message], policy_version: int + ): + episodes = [] + for i in range(group_size): + Episode( + episode_id=str(uuid.uuid4()), + request=copy.deepcopy(messages), + policy_version=policy_version, + ) + return cls(group_id, episodes) - def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int): - self.episode_id = episode_id - self.prompt = prompt - self.target = target - self.policy_version = policy_version - self.groups: list[Group] = [] - def add_group(self, group: Group): - self.groups.append(group) +# @dataclass +# class Group: +# response: str # The response text for tokenization +# ref_logprobs: torch.Tensor +# reward: float +# advantage: float = 0.0 +# +# +# class Episode: +# """Episode container for GRPO rollouts.""" +# +# def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int): +# self.episode_id = episode_id +# self.prompt = prompt +# self.target = target +# self.policy_version = policy_version +# self.groups: list[Group] = [] +# +# def add_group(self, group: Group): +# self.groups.append(group) class Trainer(ForgeActor): @@ -237,7 +290,7 @@ def __init__(self, gamma: float = 0.99, lambda_: float = 0.95): self.lambda_ = lambda_ # GAE lambda parameter @endpoint - async def __call__(self, groups: list[Group]) -> list[float]: + async def compute(self, groups: list[Group]) -> list[float]: # Extract rewards from groups rewards = [group.reward for group in groups] num_groups = len(groups) @@ -311,27 +364,27 @@ async def forward(self, 
token_ids: list[int]) -> torch.Tensor: ) # Remove batch dimension for single response +@dataclass class DatasetActor(ForgeActor): """Actor wrapper for HuggingFace dataset to provide async interface.""" - def __init__(self, *args, **kwargs): - super().__init__() - self._setup_dataset(*args, **kwargs) + path: str + name: str + split: str + streaming: bool + transform: Callable - def _setup_dataset(self, *args, **kwargs): - def gsm8k_to_messages(sample): - question = sample["question"] - full_answer: str = sample["answer"] - answer = full_answer.split("#### ")[1] - return {"question": question, "answer": answer} - - ds = load_dataset(*args, **kwargs) - ds = ds.map(gsm8k_to_messages) + @endpoint + def setup(self): + ds = load_dataset( + self.path, self.name, split=self.split, streaming=self.streaming + ) + ds = ds.map(self.transform) ds = ds.shuffle() self._iterator = iter(ds) @endpoint - async def __next__(self) -> dict[str, str] | None: + async def sample(self) -> dict[str, str] | None: try: return next(self._iterator) except StopIteration: @@ -343,6 +396,14 @@ async def main(): group_size = 1 model = "Qwen/Qwen3-1.7B" + # ---- Setup data transform ---- # + def gsm8k_to_messages(sample): + question = content = sample["question"] + full_answer: str = sample["answer"] + answer = full_answer.split("#### ")[1] + return + return [Message("user", question), Message("assistant", answer)] + # ---- Setup WandB Logger ---- # logger = get_metric_logger( "wandb", @@ -362,7 +423,7 @@ async def main(): PolicyConfig( num_workers=1, worker_params=WorkerConfig(model=model), - sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16), + sampling_params=SamplingOverrides(n=group_size, max_tokens=16), available_devices="3", ), ) @@ -390,6 +451,7 @@ async def main(): "main", split="train", streaming=True, + transform=gsm8k_to_messages, ) compute_advantages = await spawn_service( @@ -413,6 +475,8 @@ async def main(): ) print("All services initialized successfully!") + tokenizer = get_tokenizer(model) + print("philip5:", tokenizer.encode("A fake response")) # ---- Core RL loops ---- # async def continuous_rollouts(): @@ -420,21 +484,23 @@ async def continuous_rollouts(): # TODO: Move this into setup asyncio.create_task(policy.run_processing.call()) while True: - sample = await dataloader.__next__.choose() + sample = await dataloader.sample.choose() if sample is None: print("Dataloader is empty, exiting continuous rollout") return - prompt, target = sample["question"], sample["answer"] + prompt, target = sample version = 0 # await policy.get_current_version.choose() - episode = Episode( - episode_id=rollout_count, - prompt=prompt, - target=target, + group = Group.new_group( + group_id=rollout_count, + group_size=group_size, + request=sample, policy_version=version, ) - actions = await policy.generate.choose(prompt) - for action in actions: - ref_logprobs = await ref_model.forward.choose(action.token_ids) + responses = await policy.generate.choose(prompt) + for episode, response in zip(group.episodes, responses.outputs): + + episode.tokens = response.prompt_token_ids + response.token_ids + ref_logprobs = await ref_model.forward.choose(episode.tokens) reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=action.text, target=target ) @@ -446,7 +512,7 @@ async def continuous_rollouts(): ) ) - advantages = await compute_advantages.__call__.choose(episode.groups) + advantages = await compute_advantages.compute.choose(episode.groups) for advantage, group in zip(advantages, 
episode.groups): group.advantage = advantage @@ -480,6 +546,7 @@ async def continuous_training(): # await trainer.update_weights(policy) print("Starting GRPO training loops...") + # TODO: Start multiple rollouts once all serivces support it rollout_task = asyncio.create_task(continuous_rollouts()) training_task = asyncio.create_task(continuous_training()) diff --git a/apps/vllm/main.py b/apps/vllm/main.py index 111a8d5e3..b1df7c5eb 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -18,7 +18,7 @@ from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig from forge.controller.service import ServiceConfig from forge.controller.spawn import spawn_service -from vllm.outputs import CompletionOutput +from vllm.outputs import CompletionOutput, RequestOutput async def main(): @@ -68,8 +68,9 @@ def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig): ) sampling_params = SamplingOverrides( - num_samples=args.num_samples, + n=args.num_samples, guided_decoding=args.guided_decoding, + max_tokens=16, ) policy_config = PolicyConfig( @@ -89,7 +90,8 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: processing_task = asyncio.create_task(policy.run_processing.call()) print("Requesting generation...") - responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt) + request_output: RequestOutput = await policy.generate.choose(prompt=prompt) + responses: List[CompletionOutput] = request_output.outputs print("\nGeneration Results:") print("=" * 80) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 5ac24c006..26b3af2f6 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -54,14 +54,21 @@ class SamplingOverrides: subset Args: - num_samples: Number of samples to generate. + n: Number of samples to generate. guided_decoding: Whether to use guided decoding. + max_tokens: Maximum number of tokens to generate. 
""" - num_samples: int + n: int guided_decoding: bool = False max_tokens: int = 512 + def __post_init__(self): + gd_params = None + if self.guided_decoding: + gd_params = GuidedDecodingParams(choice=["Positive", "Negative"]) + self.guided_decoding = gd_params + @dataclass class WorkerConfig: @@ -116,22 +123,14 @@ async def setup(self): self.vllm_args = await self.policy_worker.get_vllm_args.choose() # Setup sampling params - sampling_overrides = self.config.sampling_params - overrides = { - "n": sampling_overrides.num_samples, - "guided_decoding": ( - GuidedDecodingParams(choice=["Positive", "Negative"]) - if sampling_overrides.guided_decoding - else None - ), - } self.sampling_params = get_default_sampling_params( - self.vllm_args, overrides=overrides + self.vllm_args, overrides=asdict(self.config.sampling_params) ) # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` + print("philip0:", self.vllm_args.model_config) tokenizer = init_tokenizer_from_configs( model_config=self.vllm_args.model_config, scheduler_config=self.vllm_args.scheduler_config, @@ -178,10 +177,12 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_id = str(self.request_id) # implement from a counter # Wraps prompt into a dict + phil_prompt = prompt prompt: Dict[str, str] = convert_input(prompt) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} + # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507 truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens _validate_truncation_size( self.vllm_args.model_config.max_model_len, @@ -201,6 +202,10 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu priority=priority, data_parallel_rank=None, ) + tokenizer = self.processor.input_preprocessor.get_tokenizer_group() + print("philip1:", request) + print("philip2:", tokenizer.encode("A fake response")) + print("philip3:", tokenizer.encode(phil_prompt + "A fake response")) # Explicitly keeping the redundant logic to make it easier to pick up # vllm changes @@ -268,7 +273,8 @@ async def run_processing(self): for request_output in processed_outputs.request_outputs: if request_output.finished: _, fut = self.requests.pop(request_output.request_id) - fut.set_result(request_output.outputs) + print("philip:", request_output) + fut.set_result(request_output) @endpoint async def update_weights(self): From a95a001ac59a23db290a23ece28fa60041b06001 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Sat, 30 Aug 2025 21:14:13 -0700 Subject: [PATCH 02/31] core updates --- apps/grpo/main.py | 368 ++++++++++++++++--------------------- apps/vllm/main.py | 16 +- src/forge/actors/policy.py | 8 +- 3 files changed, 182 insertions(+), 210 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index e5608a83c..7aede826b 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -24,67 +24,45 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -def compute_sequence_logprobs( - model: torch.nn.Module, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - requires_grad: bool = True, +def compute_logprobs( + logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 ) -> torch.Tensor: - context_manager = torch.enable_grad() if requires_grad else torch.no_grad() + context_length = logits.shape[1] - input_ids.shape[1] - with context_manager: - outputs = model(input_ids=input_ids, attention_mask=attention_mask) - logits = 
outputs.logits + # Truncate request logits and drop last + logits = logits[:, context_length - 1 : -1] - # Apply log softmax to get log probabilities - log_probs = torch.log_softmax(logits, dim=-1) + # Compute logprobs + logprobs = torch.log_softmax(logits / temperature, dim=-1) + logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1) - # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction) - shifted_input_ids = input_ids[:, 1:] # Remove first token - shifted_log_probs = log_probs[:, :-1, :] # Remove last logit - - # Gather log probabilities for actual tokens - token_log_probs = torch.gather( - shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1) - ).squeeze(-1) - - # Sum log probabilities across sequence (masked by attention) - shifted_attention_mask = attention_mask[:, 1:] - sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1) - - return sequence_log_probs - - -Role = Literal[ - "system", # Origin is system prompt - "user", # Origin is user - "assistant", # Origin is the model output - "agent", # Origin is generated - "tool", # Origin is return from a tool call -] - - -@dataclass -class Message: - role: Role - content: str + return logprobs @dataclass class Episode: # TODO: add adtional layer for multi-turn episode_id: str - request: list[Message] + request: str policy_version: int + target: Optinoal[Any] # processed data - response: list[Message] - request_tokens: Optional[torch.Tensor] - response_tokens: Optional[torch.Tensor] + response: Optional[str] + request_tokens: Optional[list[int]] + response_tokens: Optional[list[int]] ref_logprobs: Optional[torch.Tensor] = None reward: Optional[float] = None advantage: Optional[float] = None policy_version: Optional[int] = None + @property + def tokens(self): + return self.request_tokens + self.response_tokens + + @property + def mask(self): + return [0] * len(self.request_tokens) + [1] * len(self.response_tokens) + @dataclass class Group: @@ -93,7 +71,12 @@ class Group: @classmethod def new_group( - cls, group_id: int, group_size: int, request: list[Message], policy_version: int + cls, + group_id: int, + group_size: int, + request: str, + policy_version: int, + target: Any = None, ): episodes = [] for i in range(group_size): @@ -101,32 +84,11 @@ def new_group( episode_id=str(uuid.uuid4()), request=copy.deepcopy(messages), policy_version=policy_version, + target=target, ) return cls(group_id, episodes) -# @dataclass -# class Group: -# response: str # The response text for tokenization -# ref_logprobs: torch.Tensor -# reward: float -# advantage: float = 0.0 -# -# -# class Episode: -# """Episode container for GRPO rollouts.""" -# -# def __init__(self, episode_id: int, prompt: str, target: str, policy_version: int): -# self.episode_id = episode_id -# self.prompt = prompt -# self.target = target -# self.policy_version = policy_version -# self.groups: list[Group] = [] -# -# def add_group(self, group: Group): -# self.groups.append(group) - - class Trainer(ForgeActor): """GRPO Trainer implementation for policy optimization.""" @@ -172,66 +134,56 @@ async def train_step(self, batch: list[Episode]): total_loss = 0.0 num_groups_processed = 0 + # Batch logic -> move to replay replay_buffer + input_ids = [] + advantages = [] + ref_logprobs = [] for episode in batch: - groups = episode.groups - - # Collect all response texts and corresponding data - response_texts = [] - ref_logprobs_list = [] - advantages_list = [] - - for group in groups: - 
response_texts.append(group.response) - ref_logprobs_list.append(group.ref_logprobs) - advantages_list.append(group.advantage) - - # Tokenize all responses in batch - tokenized = self.tokenizer( - response_texts, - padding=True, - truncation=True, - return_tensors="pt", - max_length=512, # Adjust based on your needs - ) - - input_ids = tokenized["input_ids"].to(self.device) - attention_mask = tokenized["attention_mask"].to(self.device) - - # Compute current policy log probabilities using the model - current_logprobs = compute_sequence_logprobs( - self.model, input_ids, attention_mask, requires_grad=True - ) - - # Convert ref_logprobs and advantages to tensors - ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device) - advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to( - self.device - ) - - # Compute GRPO loss components - # Ratio between current policy and reference policy - ratio = torch.exp(current_logprobs - ref_logprobs_tensor) - - # Policy gradient loss weighted by advantages - pg_loss = -torch.mean(ratio * advantages_tensor) - - # KL penalty to prevent policy from deviating too far from reference - kl_penalty = self.beta * torch.mean( - (current_logprobs - ref_logprobs_tensor) ** 2 - ) - - # Total GRPO loss - loss = pg_loss + kl_penalty - total_loss += loss.item() - num_groups_processed += len(groups) - - self.optimizer.zero_grad() - loss.backward() - - # Gradient clipping (optional but recommended for stability) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - - self.optimizer.step() + # collect infomration and batch + pad sequences + input_ids.append(episode.response_tokens + episode.request_tokens) + # TODO philip you are here !!!!!!!!!!!!! + + # loss reference: + # https://github.com/pytorch/torchtune/blob/ + # 67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/dev/grpo/loss.py#L123 + # input_ids = tokenized["input_ids"].to(self.device) + # attention_mask = tokenized["attention_mask"].to(self.device) + # + # # Compute current policy log probabilities using the model + # current_logprobs = compute_sequence_logprobs( + # self.model, input_ids, attention_mask, requires_grad=True + # ) + # + # # Convert ref_logprobs and advantages to tensors + # ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device) + # advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to( + # self.device + # ) + # + # # Compute GRPO loss components + # # Ratio between current policy and reference policy + # ratio = torch.exp(current_logprobs - ref_logprobs_tensor) + # + # # Policy gradient loss weighted by advantages + # pg_loss = -torch.mean(ratio * advantages_tensor) + # + # # KL penalty to prevent policy from deviating too far from reference + # kl_penalty = self.beta * torch.mean( + # (current_logprobs - ref_logprobs_tensor) ** 2 + # ) + # + # # Total GRPO loss + # loss = pg_loss + kl_penalty + # total_loss += loss.item() + # num_groups_processed += len(groups) + + self.optimizer.zero_grad() + loss.backward() + + # Gradient clipping (optional but recommended for stability) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + self.optimizer.step() avg_loss = total_loss / len(batch) if batch else 0.0 @@ -274,6 +226,9 @@ def __init__(self, reward_functions: list[Callable]): @endpoint async def evaluate_response(self, prompt: str, response: str, target: str) -> float: + # philip: compare against + # https://github.com/pytorch/torchtune/blob/ + # 
67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/dev/rl/rewards.py#L270 total_reward = 0.0 for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) @@ -286,40 +241,46 @@ class ComputeAdvantages(ForgeActor): def __init__(self, gamma: float = 0.99, lambda_: float = 0.95): super().__init__() - self.gamma = gamma # Discount factor - self.lambda_ = lambda_ # GAE lambda parameter + # self.gamma = gamma # Discount factor + # self.lambda_ = lambda_ # GAE lambda parameter @endpoint - async def compute(self, groups: list[Group]) -> list[float]: - # Extract rewards from groups - rewards = [group.reward for group in groups] - num_groups = len(groups) - - # For simplicity, use reward-to-go as advantages - # This is a valid advantage estimator: A(s,a) = Q(s,a) - V(s) - # where Q(s,a) ≈ reward-to-go and V(s) ≈ average reward - - # Compute discounted reward-to-go for each step - reward_to_go = [] - running_reward = 0.0 - - # Calculate discounted returns (reward-to-go) - for t in reversed(range(num_groups)): - running_reward = rewards[t] + self.gamma * running_reward - reward_to_go.insert(0, running_reward) - - # Compute baseline (mean of rewards) and advantages - baseline = sum(rewards) / len(rewards) if rewards else 0.0 - advantages = [rtg - baseline for rtg in reward_to_go] - - # Normalize advantages to have zero mean and unit variance - if len(advantages) > 1: - mean_adv = sum(advantages) / len(advantages) - var_adv = sum((a - mean_adv) ** 2 for a in advantages) / len(advantages) - std_adv = (var_adv**0.5) if var_adv > 1e-8 else 1.0 - advantages = [(a - mean_adv) / std_adv for a in advantages] + async def compute(self, group: Group) -> list[float]: + # TODO: add batch processing + rewards = torch.Tensor([[e.reward for e in group.episodes]]) + advantages = (rewards - rewards.mean(1, keepdim=True)) / ( + rewards.std(1, keepdim=True) + 1e-4 + ) - return advantages + # # Extract rewards from groups + # rewards = [group.reward for group in groups] + # num_groups = len(groups) + # + # # For simplicity, use reward-to-go as advantages + # # This is a valid advantage estimator: A(s,a) = Q(s,a) - V(s) + # # where Q(s,a) ≈ reward-to-go and V(s) ≈ average reward + # + # # Compute discounted reward-to-go for each step + # reward_to_go = [] + # running_reward = 0.0 + # + # # Calculate discounted returns (reward-to-go) + # for t in reversed(range(num_groups)): + # running_reward = rewards[t] + self.gamma * running_reward + # reward_to_go.insert(0, running_reward) + # + # # Compute baseline (mean of rewards) and advantages + # baseline = sum(rewards) / len(rewards) if rewards else 0.0 + # advantages = [rtg - baseline for rtg in reward_to_go] + # + # # Normalize advantages to have zero mean and unit variance + # if len(advantages) > 1: + # mean_adv = sum(advantages) / len(advantages) + # var_adv = sum((a - mean_adv) ** 2 for a in advantages) / len(advantages) + # std_adv = (var_adv**0.5) if var_adv > 1e-8 else 1.0 + # advantages = [(a - mean_adv) / std_adv for a in advantages] + + return advantages.squeeze(0) class RefModel(ForgeActor): @@ -346,22 +307,22 @@ def __init__(self, model_name, device: torch.device | None = None): self.logger.info(f"Model initialized on {self.device}") @endpoint - async def forward(self, token_ids: list[int]) -> torch.Tensor: - # Use provided token_ids directly - input_ids = ( - torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device) - ) - # Create attention mask of all 1s since we have actual tokens (no padding) - attention_mask = 
torch.ones_like(input_ids).to(self.device) + async def forward(self, request: list[int], response: list[int]) -> torch.Tensor: - # Compute log probabilities using shared utility function - sequence_log_probs = compute_sequence_logprobs( - self.model, input_ids, attention_mask, requires_grad=False - ) + # Convert tokens to tensor + input_ids = torch.tensor( + request + response, dtype=torch.long, device=self.device + ).unsqueeze(0) + + # Compute logits + with torch.inference(): + logits = model(input_ids=input_ids).logits - return ( - sequence_log_probs.squeeze() - ) # Remove batch dimension for single response + # Compute logprobs + input_ids = input_ids[:, len(response) :] + logprobs = compute_logprobs(logits, input_ids) + + return logprobs @dataclass @@ -372,14 +333,27 @@ class DatasetActor(ForgeActor): name: str split: str streaming: bool - transform: Callable + model: str @endpoint def setup(self): + self.tokenizer = get_tokenizer(self.model) + + def gsm8k_transform(sample): + request: str = sample["question"] + formatted_request = self.tokenizer.apply_chat_template( + [{"role": "user", "content": request}], + tokenize=False, + add_generation_prompt=True, + ) + target: str = sample["answer"] + formatted_target = target.split("#### ")[1] + return {"request": formatted_request, "target": formatted_target} + ds = load_dataset( self.path, self.name, split=self.split, streaming=self.streaming ) - ds = ds.map(self.transform) + ds = ds.map(gsm8k_transform) ds = ds.shuffle() self._iterator = iter(ds) @@ -396,14 +370,6 @@ async def main(): group_size = 1 model = "Qwen/Qwen3-1.7B" - # ---- Setup data transform ---- # - def gsm8k_to_messages(sample): - question = content = sample["question"] - full_answer: str = sample["answer"] - answer = full_answer.split("#### ")[1] - return - return [Message("user", question), Message("assistant", answer)] - # ---- Setup WandB Logger ---- # logger = get_metric_logger( "wandb", @@ -451,7 +417,7 @@ def gsm8k_to_messages(sample): "main", split="train", streaming=True, - transform=gsm8k_to_messages, + model=model, ) compute_advantages = await spawn_service( @@ -475,8 +441,6 @@ def gsm8k_to_messages(sample): ) print("All services initialized successfully!") - tokenizer = get_tokenizer(model) - print("philip5:", tokenizer.encode("A fake response")) # ---- Core RL loops ---- # async def continuous_rollouts(): @@ -488,41 +452,33 @@ async def continuous_rollouts(): if sample is None: print("Dataloader is empty, exiting continuous rollout") return - prompt, target = sample + prompt, target = sample["request"], sample["target"] version = 0 # await policy.get_current_version.choose() group = Group.new_group( group_id=rollout_count, group_size=group_size, - request=sample, + request=prompt, policy_version=version, + target=target, ) responses = await policy.generate.choose(prompt) for episode, response in zip(group.episodes, responses.outputs): - - episode.tokens = response.prompt_token_ids + response.token_ids - ref_logprobs = await ref_model.forward.choose(episode.tokens) - reward = await reward_actor.evaluate_response.choose( - prompt=prompt, response=action.text, target=target + episode.request_tokens = responses.prompt_token_ids + episode.response_tokens = response.token_ids + episode.ref_logprobs = await ref_model.forward.choose( + request=episode.request_tokens, response=episode.response_tokens ) - episode.add_group( - Group( - response=action.text, - ref_logprobs=ref_logprobs, - reward=reward, - ) + episode.reward = await reward_actor.evaluate_response.choose( + 
prompt=prompt, response=response.text, target=target ) - - advantages = await compute_advantages.compute.choose(episode.groups) - for advantage, group in zip(advantages, episode.groups): - group.advantage = advantage - - await replay_buffer.add.choose(episode) + advantages = await compute_advantages.compute.choose(group) + for episode, advantage in zip(group.episodes, advantages): + episode.advantage = advantage + await replay_buffer.add.choose(episode) rollout_count += 1 if rollout_count % 10 == 0: - avg_reward = sum(group.reward for group in episode.groups) / len( - episode.groups - ) + avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes) print( f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" ) diff --git a/apps/vllm/main.py b/apps/vllm/main.py index b1df7c5eb..b9f702cf3 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -20,6 +20,13 @@ from forge.controller.spawn import spawn_service from vllm.outputs import CompletionOutput, RequestOutput +# philip +from vllm.transformers_utils.tokenizer import get_tokenizer + +# convert to messages +# 2 versions: formatted prompt, formatted full sequence +# - remove eod if vllm didn't finish sequence + async def main(): """Main application for running vLLM policy inference.""" @@ -34,6 +41,13 @@ async def main(): else: prompt = args.prompt + # philip: format prompt + tokenizer = get_tokenizer(policy_config.worker_params.model) + messages = [{"role": "user", "content": prompt}] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + # Run the policy await run_vllm(service_config, policy_config, prompt) @@ -43,7 +57,7 @@ def parse_args() -> Namespace: parser.add_argument( "--model", type=str, - default="meta-llama/Llama-3.1-8B-Instruct", + default="Qwen/Qwen3-1.7B", # "meta-llama/Llama-3.1-8B-Instruct", help="Model to use", ) parser.add_argument( diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 26b3af2f6..9d895131a 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -178,7 +178,9 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu # Wraps prompt into a dict phil_prompt = prompt - prompt: Dict[str, str] = convert_input(prompt) + prompt: Dict[str, str] = convert_input( + prompt_token_ids=prompt + ) # philip remove key # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -204,8 +206,8 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu ) tokenizer = self.processor.input_preprocessor.get_tokenizer_group() print("philip1:", request) - print("philip2:", tokenizer.encode("A fake response")) - print("philip3:", tokenizer.encode(phil_prompt + "A fake response")) + # print("philip2:", tokenizer.encode("A fake response")) + # print("philip3:", tokenizer.encode(phil_prompt + "A fake response")) # Explicitly keeping the redundant logic to make it easier to pick up # vllm changes From 3ba0df6a0168434b87a06962b2da400050fac6a9 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Mon, 1 Sep 2025 14:37:35 -0700 Subject: [PATCH 03/31] batch update --- apps/grpo/main.py | 240 ++++++++++++++++++------------------- apps/vllm/main.py | 9 +- src/forge/actors/policy.py | 10 +- 3 files changed, 116 insertions(+), 143 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 7aede826b..a8838e241 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -9,7 +9,7 @@ import time import uuid from dataclasses import dataclass -from typing import 
Callable +from typing import Any, Callable, Optional import torch from datasets import load_dataset @@ -20,7 +20,8 @@ from forge.data.rewards import MathReward, ThinkingReward from forge.util.metric_logging import get_metric_logger from monarch.actor import endpoint -from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import nn +from transformers import AutoModelForCausalLM from vllm.transformers_utils.tokenizer import get_tokenizer @@ -39,29 +40,59 @@ def compute_logprobs( return logprobs +class SimpleGRPOLoss(nn.Module): + """Simplified GRPO Loss for simplified single step updates""" + + def __init__(self, epsilon=0.1, beta=0.1): + super().__init__() + self.epsilon = epsilon + self.beta = beta + + def forward(self, logprobs, ref_logprobs, advantages, padding_mask): + log_ratio = ref_logprobs.detach() - logprobs + kl = torch.exp(log_ratio) - log_ratio - 1 + + pl = torch.exp(logprobs - logprobs.detach()) * advantages + loss = -pl + self.beta * kl + + # Compute mean + loss = (loss * padding_mask).sum() / (padding_mask.sum() + 1e-8) + return loss + + @dataclass class Episode: # TODO: add adtional layer for multi-turn episode_id: str request: str policy_version: int - target: Optinoal[Any] + pad_id: int + request_len: int + response_len: int + target: Optional[Any] = None # processed data - response: Optional[str] - request_tokens: Optional[list[int]] - response_tokens: Optional[list[int]] + response: Optional[str] = None + request_tokens: Optional[list[int]] = None + response_tokens: Optional[list[int]] = None ref_logprobs: Optional[torch.Tensor] = None reward: Optional[float] = None advantage: Optional[float] = None - policy_version: Optional[int] = None @property - def tokens(self): - return self.request_tokens + self.response_tokens + def request_tensor(self): + tensor = torch.tensor(self.request_tokens, dtype=torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor @property - def mask(self): - return [0] * len(self.request_tokens) + [1] * len(self.response_tokens) + def response_tensor(self): + tensor = torch.tensor(self.response_tokens, dtype=torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor @dataclass @@ -76,6 +107,9 @@ def new_group( group_size: int, request: str, policy_version: int, + pad_id: int, + request_len: int, + response_len: int, target: Any = None, ): episodes = [] @@ -84,33 +118,30 @@ def new_group( episode_id=str(uuid.uuid4()), request=copy.deepcopy(messages), policy_version=policy_version, + pad_id=pad_iddd, + request_len=request_len, + response_len=response_len, target=target, ) return cls(group_id, episodes) +@dataclass class Trainer(ForgeActor): """GRPO Trainer implementation for policy optimization.""" - def __init__( - self, - learning_rate: float = 1e-5, - beta: float = 0.1, - model_name: str = "", - device: torch.device | None = None, - ): - super().__init__() - self.learning_rate = learning_rate - self.beta = beta # KL penalty coefficient - self.model_name = model_name + model_name: str + learning_rate: float = 1e-5 + beta: float = 0.1 + epsilon: float = 0.1 + device: torch.device | None = None + def setup(self): # Set device - if device is None: + if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self.device = device - # Initialize 
model and tokenizer + # Initialize model self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, @@ -118,64 +149,47 @@ def __init__( ).to(self.device) self.model.train() - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - # Initialize optimizer self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=self.learning_rate ) + self.optimizer.zero_grad() + + # Initialize loss + self.loss = SimpleGRPOLoss(self.epsilon, self.beta) self.logger.info(f"Model initialized on {self.device}") @endpoint async def train_step(self, batch: list[Episode]): total_loss = 0.0 - num_groups_processed = 0 - - # Batch logic -> move to replay replay_buffer - input_ids = [] - advantages = [] - ref_logprobs = [] - for episode in batch: - # collect infomration and batch + pad sequences - input_ids.append(episode.response_tokens + episode.request_tokens) - # TODO philip you are here !!!!!!!!!!!!! - - # loss reference: - # https://github.com/pytorch/torchtune/blob/ - # 67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/dev/grpo/loss.py#L123 - # input_ids = tokenized["input_ids"].to(self.device) - # attention_mask = tokenized["attention_mask"].to(self.device) - # - # # Compute current policy log probabilities using the model - # current_logprobs = compute_sequence_logprobs( - # self.model, input_ids, attention_mask, requires_grad=True - # ) - # - # # Convert ref_logprobs and advantages to tensors - # ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device) - # advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to( - # self.device - # ) - # - # # Compute GRPO loss components - # # Ratio between current policy and reference policy - # ratio = torch.exp(current_logprobs - ref_logprobs_tensor) - # - # # Policy gradient loss weighted by advantages - # pg_loss = -torch.mean(ratio * advantages_tensor) - # - # # KL penalty to prevent policy from deviating too far from reference - # kl_penalty = self.beta * torch.mean( - # (current_logprobs - ref_logprobs_tensor) ** 2 - # ) - # - # # Total GRPO loss - # loss = pg_loss + kl_penalty - # total_loss += loss.item() - # num_groups_processed += len(groups) + num_episodes_processed = 0 + pad_id = batch[0].pad_id + + # prepare batch + request = [e.response_tokens for e in batch] + request = torch.stack(request).to(self.device) + + response = [e.response_tokens for e in batch] + response = torch.stack(response).to(self.device) + + ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = torch.stack(ref_logprobs).to(self.device) + + advantages = [e.advantages for e in batch] + advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) + del batch + + # compute policy logprobs + input_ids = torch.cat([request, response]) + mask = input_ids[1] != pad_id + logits = self.model(input_ids=input_ids, attention_mask=mask).logits + logprobs = compute_logprobs(logits, response) + del logits + + # compute loss + mask = (response != pad_id).unsqueeze(-1) + loss = self.loss(logprobs, ref_logprobs, advantages, mask) self.optimizer.zero_grad() loss.backward() @@ -185,9 +199,10 @@ async def train_step(self, batch: list[Episode]): self.optimizer.step() + total_loss += loss.item() avg_loss = total_loss / len(batch) if batch else 0.0 - return {"loss": avg_loss, "groups_processed": num_groups_processed} + return {"loss": avg_loss, "episodes_processed": num_episodes_processed} @endpoint async def update_weights(self, 
policy_actor): @@ -226,9 +241,6 @@ def __init__(self, reward_functions: list[Callable]): @endpoint async def evaluate_response(self, prompt: str, response: str, target: str) -> float: - # philip: compare against - # https://github.com/pytorch/torchtune/blob/ - # 67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/dev/rl/rewards.py#L270 total_reward = 0.0 for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) @@ -239,47 +251,13 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl class ComputeAdvantages(ForgeActor): """Compute advantages for GRPO using reward signals.""" - def __init__(self, gamma: float = 0.99, lambda_: float = 0.95): - super().__init__() - # self.gamma = gamma # Discount factor - # self.lambda_ = lambda_ # GAE lambda parameter - @endpoint async def compute(self, group: Group) -> list[float]: # TODO: add batch processing rewards = torch.Tensor([[e.reward for e in group.episodes]]) - advantages = (rewards - rewards.mean(1, keepdim=True)) / ( + advantages = (rewards - rewards.me / an(1, keepdim=True)) / ( rewards.std(1, keepdim=True) + 1e-4 ) - - # # Extract rewards from groups - # rewards = [group.reward for group in groups] - # num_groups = len(groups) - # - # # For simplicity, use reward-to-go as advantages - # # This is a valid advantage estimator: A(s,a) = Q(s,a) - V(s) - # # where Q(s,a) ≈ reward-to-go and V(s) ≈ average reward - # - # # Compute discounted reward-to-go for each step - # reward_to_go = [] - # running_reward = 0.0 - # - # # Calculate discounted returns (reward-to-go) - # for t in reversed(range(num_groups)): - # running_reward = rewards[t] + self.gamma * running_reward - # reward_to_go.insert(0, running_reward) - # - # # Compute baseline (mean of rewards) and advantages - # baseline = sum(rewards) / len(rewards) if rewards else 0.0 - # advantages = [rtg - baseline for rtg in reward_to_go] - # - # # Normalize advantages to have zero mean and unit variance - # if len(advantages) > 1: - # mean_adv = sum(advantages) / len(advantages) - # var_adv = sum((a - mean_adv) ** 2 for a in advantages) / len(advantages) - # std_adv = (var_adv**0.5) if var_adv > 1e-8 else 1.0 - # advantages = [(a - mean_adv) / std_adv for a in advantages] - return advantages.squeeze(0) @@ -307,19 +285,19 @@ def __init__(self, model_name, device: torch.device | None = None): self.logger.info(f"Model initialized on {self.device}") @endpoint - async def forward(self, request: list[int], response: list[int]) -> torch.Tensor: + async def forward(self, episode: Episode) -> torch.Tensor: # Convert tokens to tensor - input_ids = torch.tensor( - request + response, dtype=torch.long, device=self.device - ).unsqueeze(0) + req, res = episode.request_tensor, episode.response_tensor + input_ids = torch.cat([request, response]).to(self.device).unsqueeze(0) + mask = input_ids[1] != episode.pad_id # Compute logits with torch.inference(): - logits = model(input_ids=input_ids).logits + logits = model(input_ids=input_ids, attention_mask=mask).logits # Compute logprobs - input_ids = input_ids[:, len(response) :] + input_ids = input_ids[:, request.shape[1] :] logprobs = compute_logprobs(logits, input_ids) return logprobs @@ -331,7 +309,7 @@ class DatasetActor(ForgeActor): path: str name: str - split: str + data_split: str streaming: bool model: str @@ -351,7 +329,7 @@ def gsm8k_transform(sample): return {"request": formatted_request, "target": formatted_target} ds = load_dataset( - self.path, self.name, split=self.split, streaming=self.streaming + 
self.path, self.name, split=self.data_split, streaming=self.streaming ) ds = ds.map(gsm8k_transform) ds = ds.shuffle() @@ -364,11 +342,19 @@ async def sample(self) -> dict[str, str] | None: except StopIteration: return None + @endpoint + def pad_token(self): + if self.tokenizer.pad_token is None: + return self.tokenizer.eos_token + return self.tokenizer.pad_token + async def main(): """Main GRPO training loop with rollout and training processes.""" group_size = 1 model = "Qwen/Qwen3-1.7B" + max_req_tokens = 512 + max_res_tokens = 128 # ---- Setup WandB Logger ---- # logger = get_metric_logger( @@ -389,7 +375,7 @@ async def main(): PolicyConfig( num_workers=1, worker_params=WorkerConfig(model=model), - sampling_params=SamplingOverrides(n=group_size, max_tokens=16), + sampling_params=SamplingOverrides(n=group_size, max_tokens=max_res_tokens), available_devices="3", ), ) @@ -415,7 +401,7 @@ async def main(): DatasetActor, "openai/gsm8k", "main", - split="train", + data_split="train", streaming=True, model=model, ) @@ -445,6 +431,7 @@ async def main(): # ---- Core RL loops ---- # async def continuous_rollouts(): rollout_count = 0 + pad_id = dataloader.pad_token.choose() # TODO: Move this into setup asyncio.create_task(policy.run_processing.call()) while True: @@ -459,15 +446,16 @@ async def continuous_rollouts(): group_size=group_size, request=prompt, policy_version=version, + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, target=target, ) responses = await policy.generate.choose(prompt) for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids episode.response_tokens = response.token_ids - episode.ref_logprobs = await ref_model.forward.choose( - request=episode.request_tokens, response=episode.response_tokens - ) + episode.ref_logprobs = await ref_model.forward.choose(episode) episode.reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=response.text, target=target ) diff --git a/apps/vllm/main.py b/apps/vllm/main.py index b9f702cf3..a20e9e584 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -20,13 +20,6 @@ from forge.controller.spawn import spawn_service from vllm.outputs import CompletionOutput, RequestOutput -# philip -from vllm.transformers_utils.tokenizer import get_tokenizer - -# convert to messages -# 2 versions: formatted prompt, formatted full sequence -# - remove eod if vllm didn't finish sequence - async def main(): """Main application for running vLLM policy inference.""" @@ -41,7 +34,7 @@ async def main(): else: prompt = args.prompt - # philip: format prompt + # format prompt tokenizer = get_tokenizer(policy_config.worker_params.model) messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 9d895131a..c85b093c6 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -130,7 +130,6 @@ async def setup(self): # Setup processors # TODO: move all processing to the Environment # TODO: add support for `log_stats` and `mm_registry` - print("philip0:", self.vllm_args.model_config) tokenizer = init_tokenizer_from_configs( model_config=self.vllm_args.model_config, scheduler_config=self.vllm_args.scheduler_config, @@ -177,10 +176,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_id = str(self.request_id) # implement from a counter # Wraps prompt into a dict - phil_prompt = prompt - prompt: Dict[str, str] 
= convert_input( - prompt_token_ids=prompt - ) # philip remove key + prompt: Dict[str, str] = convert_input(prompt_token_ids=prompt) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -205,9 +201,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu data_parallel_rank=None, ) tokenizer = self.processor.input_preprocessor.get_tokenizer_group() - print("philip1:", request) - # print("philip2:", tokenizer.encode("A fake response")) - # print("philip3:", tokenizer.encode(phil_prompt + "A fake response")) # Explicitly keeping the redundant logic to make it easier to pick up # vllm changes @@ -275,7 +268,6 @@ async def run_processing(self): for request_output in processed_outputs.request_outputs: if request_output.finished: _, fut = self.requests.pop(request_output.request_id) - print("philip:", request_output) fut.set_result(request_output) @endpoint From 3e32264f62c4f010ff8b65a50af99c0431d15fd8 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Mon, 1 Sep 2025 18:24:50 -0700 Subject: [PATCH 04/31] fix typo --- apps/grpo/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index a8838e241..7b464db31 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -136,6 +136,7 @@ class Trainer(ForgeActor): epsilon: float = 0.1 device: torch.device | None = None + @endpoint def setup(self): # Set device if self.device is None: @@ -255,7 +256,7 @@ class ComputeAdvantages(ForgeActor): async def compute(self, group: Group) -> list[float]: # TODO: add batch processing rewards = torch.Tensor([[e.reward for e in group.episodes]]) - advantages = (rewards - rewards.me / an(1, keepdim=True)) / ( + advantages = (rewards - rewards.mean(1, keepdim=True)) / ( rewards.std(1, keepdim=True) + 1e-4 ) return advantages.squeeze(0) From 52028a5bbc41e6adee40061b10e74da7de264b9f Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Mon, 1 Sep 2025 21:05:07 -0700 Subject: [PATCH 05/31] missing import --- apps/vllm/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/vllm/main.py b/apps/vllm/main.py index dae08ecea..dd3b8eae3 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -18,6 +18,7 @@ from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig from forge.controller.service import ServiceConfig, shutdown_service, spawn_service from vllm.outputs import CompletionOutput, RequestOutput +from vllm.transformers_utils.tokenizer import get_tokenizer async def main(): From e2a3a6894f53b99eb8f80187a656fa71ffd1ec8a Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Tue, 2 Sep 2025 14:45:36 -0700 Subject: [PATCH 06/31] debug merge --- apps/grpo/main.py | 29 +++++++++++++++-------------- src/forge/actors/policy.py | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index e3d514083..1737ee745 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. 
import asyncio -import copy import logging import time import uuid @@ -118,14 +117,16 @@ def new_group( ): episodes = [] for i in range(group_size): - Episode( - episode_id=str(uuid.uuid4()), - request=copy.deepcopy(messages), - policy_version=policy_version, - pad_id=pad_iddd, - request_len=request_len, - response_len=response_len, - target=target, + episodes.append( + Episode( + episode_id=str(uuid.uuid4()), + request=request, + policy_version=policy_version, + pad_id=pad_id, + request_len=request_len, + response_len=response_len, + target=target, + ) ) return cls(group_id, episodes) @@ -148,7 +149,7 @@ def setup(self): # Initialize model self.model = AutoModelForCausalLM.from_pretrained( - model_name, + self.model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) @@ -313,7 +314,7 @@ class DatasetActor(ForgeActor): """Actor wrapper for HuggingFace dataset to provide async interface.""" path: str - name: str + revision: str data_split: str streaming: bool model: str @@ -334,7 +335,7 @@ def gsm8k_transform(sample): return {"request": formatted_request, "target": formatted_target} ds = load_dataset( - self.path, self.name, split=self.data_split, streaming=self.streaming + self.path, self.revision, split=self.data_split, streaming=self.streaming ) ds = ds.map(gsm8k_transform) ds = ds.shuffle() @@ -382,7 +383,7 @@ async def main(): ServiceConfig(procs_per_replica=1, num_replicas=1), DatasetActor, path="openai/gsm8k", - name="main", + revision="main", data_split="train", streaming=True, model=model, @@ -416,7 +417,7 @@ async def main(): spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True), RefModel, - model=titan_model, + model_name=model, ), spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 9c9a40efa..9cfb47b46 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -226,7 +226,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_id = str(self.request_id) # implement from a counter # Wraps prompt into a dict - prompt: Dict[str, str] = convert_input(prompt_token_ids=prompt) + prompt: Dict[str, str] = convert_input(prompt=prompt) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} From 2cf9d00a4cc5325f5e02bdcf9597bf29630e2bb5 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Thu, 4 Sep 2025 07:58:13 -0700 Subject: [PATCH 07/31] more fixes --- apps/grpo/main.py | 52 +++++++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1737ee745..cd15c643f 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Optional import torch +import torch.nn.functional as F from datasets import load_dataset from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig from forge.actors.replay_buffer import ReplayBuffer @@ -173,23 +174,28 @@ async def train_step(self, batch: list[Episode]): pad_id = batch[0].pad_id # prepare batch - request = [e.response_tokens for e in batch] - request = torch.stack(request).to(self.device) + request = [e.response_tensor for e in batch] + request = torch.stack(request).to(self.device) # [b x s] + print("phil1", request.shape) - response = [e.response_tokens for e in batch] - response = torch.stack(response).to(self.device) + response = [e.response_tensor for e in batch] + response = 
torch.stack(response).to(self.device) # [b x s] + print("phil2", response.shape) ref_logprobs = [e.ref_logprobs for e in batch] - ref_logprobs = torch.stack(ref_logprobs).to(self.device) + ref_logprobs = torch.stack(ref_logprobs).to(self.device) # [b x s x d] + print("phil3", ref_logprobs.shape) - advantages = [e.advantages for e in batch] - advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) + advantages = [e.advantage for e in batch] + advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] + print("phil4", advantages) del batch # compute policy logprobs input_ids = torch.cat([request, response]) - mask = input_ids[1] != pad_id + mask = input_ids != pad_id logits = self.model(input_ids=input_ids, attention_mask=mask).logits + print("phil5", logits.shape) logprobs = compute_logprobs(logits, response) del logits @@ -238,12 +244,10 @@ async def update_weights(self, policy_actor): self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds") +@dataclass class RewardActor(ForgeActor): """Reward actor that uses a list of scoring functions.""" - - def __init__(self, reward_functions: list[Callable]): - super().__init__() - self.reward_functions = reward_functions + reward_functions: list[Callable] @endpoint async def evaluate_response(self, prompt: str, response: str, target: str) -> float: @@ -251,6 +255,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) total_reward += reward + print("phil_rew_0", total_reward) return total_reward @@ -261,9 +266,11 @@ class ComputeAdvantages(ForgeActor): async def compute(self, group: Group) -> list[float]: # TODO: add batch processing rewards = torch.Tensor([[e.reward for e in group.episodes]]) + print("phil", rewards) advantages = (rewards - rewards.mean(1, keepdim=True)) / ( rewards.std(1, keepdim=True) + 1e-4 ) + print("phil-1", advantages.squeeze(0)) return advantages.squeeze(0) @@ -295,16 +302,19 @@ async def forward(self, episode: Episode) -> torch.Tensor: # Convert tokens to tensor req, res = episode.request_tensor, episode.response_tensor - input_ids = torch.cat([request, response]).to(self.device).unsqueeze(0) - mask = input_ids[1] != episode.pad_id + input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0) + mask = input_ids != episode.pad_id # Compute logits - with torch.inference(): - logits = model(input_ids=input_ids, attention_mask=mask).logits + with torch.inference_mode(): + logits = self.model(input_ids=input_ids, attention_mask=mask).logits # Compute logprobs - input_ids = input_ids[:, request.shape[1] :] + input_ids = input_ids[:, len(req) :] + print("phil_ref_0", input_ids.shape) + print("phil_ref_1", logits.shape) logprobs = compute_logprobs(logits, input_ids) + print("phil_ref_2", logprobs.shape) return logprobs @@ -349,10 +359,8 @@ async def sample(self) -> dict[str, str] | None: return None @endpoint - def pad_token(self): - if self.tokenizer.pad_token is None: - return self.tokenizer.eos_token - return self.tokenizer.pad_token + async def pad_token(self): + return self.tokenizer.pad_token_id async def main(): @@ -431,7 +439,7 @@ async def main(): # ---- Core RL loops ---- # async def continuous_rollouts(): rollout_count = 0 - pad_id = dataloader.pad_token.choose() + pad_id = await dataloader.pad_token.choose() while True: sample = await dataloader.sample.choose() if sample is None: From b85320cc080c8658a8c7445d114ef9a649fe07a0 Mon Sep 17 00:00:00 2001 
From: joecummings Date: Thu, 4 Sep 2025 08:05:31 -0700 Subject: [PATCH 08/31] Remove dtype warnings --- apps/grpo/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index cd15c643f..38d48c857 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -151,7 +151,7 @@ def setup(self): # Initialize model self.model = AutoModelForCausalLM.from_pretrained( self.model_name, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) self.model.train() @@ -288,7 +288,7 @@ def __init__(self, model_name, device: torch.device | None = None): # Initialize model and tokenizer self.model = AutoModelForCausalLM.from_pretrained( model_name, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) From f7626ce0da60651e338ad01c839ee0ab401bb238 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 10:50:04 -0700 Subject: [PATCH 09/31] Stub --- apps/grpo/main.py | 102 +++++++++++++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 38d48c857..52bfeca46 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -53,15 +53,33 @@ def __init__(self, epsilon=0.1, beta=0.1): self.beta = beta def forward(self, logprobs, ref_logprobs, advantages, padding_mask): - log_ratio = ref_logprobs.detach() - logprobs - kl = torch.exp(log_ratio) - log_ratio - 1 + per_token_kl = ( + torch.exp(ref_logprobs.detach() - logprobs) + - (ref_logprobs.detach() - logprobs) + - 1 + ) + + advantages = advantages[:, None] # [B x G, 1] + + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) - pl = torch.exp(logprobs - logprobs.detach()) * advantages - loss = -pl + self.beta * kl + loss = (per_token_loss * padding_mask).sum(dim=1) / ( + padding_mask.sum(dim=1) + 1e-8 + ).mean() - # Compute mean - loss = (loss * padding_mask).sum() / (padding_mask.sum() + 1e-8) return loss + # log_ratio = ref_logprobs.detach() - logprobs + # kl = torch.exp(log_ratio) - log_ratio - 1 + + # pl = torch.exp(logprobs - logprobs.detach()) * advantages + # loss = -pl + self.beta * kl + + # print(loss.shape, padding_mask.shape) + + # # Compute mean + # loss = (loss * padding_mask).sum() / (padding_mask.sum() + 1e-8) + # return loss @dataclass @@ -94,7 +112,7 @@ def request_tensor(self): def response_tensor(self): tensor = torch.tensor(self.response_tokens, dtype=torch.long) if tensor.shape[0] < self.response_len: # right pad - diff = self.request_len - tensor.shape[0] + diff = self.response_len - tensor.shape[0] tensor = F.pad(tensor, (0, diff), value=self.pad_id) return tensor @@ -129,7 +147,7 @@ def new_group( target=target, ) ) - return cls(group_id, episodes) + return cls(str(group_id), episodes) @dataclass @@ -172,35 +190,31 @@ async def train_step(self, batch: list[Episode]): total_loss = 0.0 num_episodes_processed = 0 pad_id = batch[0].pad_id + bsz = len(batch) # prepare batch - request = [e.response_tensor for e in batch] - request = torch.stack(request).to(self.device) # [b x s] - print("phil1", request.shape) + request = [e.request_tensor for e in batch] + request = torch.stack(request).to(self.device) # [b x s] response = [e.response_tensor for e in batch] - response = torch.stack(response).to(self.device) # [b x s] - print("phil2", response.shape) + response = torch.stack(response).to(self.device) # [b x s] ref_logprobs = 
[e.ref_logprobs for e in batch] - ref_logprobs = torch.stack(ref_logprobs).to(self.device) # [b x s x d] - print("phil3", ref_logprobs.shape) + ref_logprobs = torch.stack(ref_logprobs).to(self.device) # [b x s x d] advantages = [e.advantage for e in batch] - advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] - print("phil4", advantages) + advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] del batch # compute policy logprobs - input_ids = torch.cat([request, response]) + input_ids = torch.cat([request, response], dim=1) mask = input_ids != pad_id logits = self.model(input_ids=input_ids, attention_mask=mask).logits - print("phil5", logits.shape) logprobs = compute_logprobs(logits, response) del logits # compute loss - mask = (response != pad_id).unsqueeze(-1) + mask = response != pad_id loss = self.loss(logprobs, ref_logprobs, advantages, mask) self.optimizer.zero_grad() @@ -212,7 +226,7 @@ async def train_step(self, batch: list[Episode]): self.optimizer.step() total_loss += loss.item() - avg_loss = total_loss / len(batch) if batch else 0.0 + avg_loss = total_loss / bsz return {"loss": avg_loss, "episodes_processed": num_episodes_processed} @@ -247,6 +261,7 @@ async def update_weights(self, policy_actor): @dataclass class RewardActor(ForgeActor): """Reward actor that uses a list of scoring functions.""" + reward_functions: list[Callable] @endpoint @@ -266,12 +281,21 @@ class ComputeAdvantages(ForgeActor): async def compute(self, group: Group) -> list[float]: # TODO: add batch processing rewards = torch.Tensor([[e.reward for e in group.episodes]]) - print("phil", rewards) - advantages = (rewards - rewards.mean(1, keepdim=True)) / ( - rewards.std(1, keepdim=True) + 1e-4 - ) - print("phil-1", advantages.squeeze(0)) - return advantages.squeeze(0) + print("rewards", rewards) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + + print(std) + + # if std is nan, return 0s. 
Remove this before shipping + if std.isnan().any(): + advantages = torch.zeros_like(rewards) + else: + advantages = (rewards - mean) / (std + 1e-4) + + x = advantages.squeeze(0).tolist() + print("phil_adv_0", x) + return x class RefModel(ForgeActor): @@ -279,27 +303,22 @@ def __init__(self, model_name, device: torch.device | None = None): super().__init__() self.model_name = model_name - # Set device if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = device - # Initialize model and tokenizer self.model = AutoModelForCausalLM.from_pretrained( model_name, dtype=torch.bfloat16, trust_remote_code=True, ).to(self.device) - - # Set model to eval mode for reference computations self.model.eval() self.logger.info(f"Model initialized on {self.device}") @endpoint async def forward(self, episode: Episode) -> torch.Tensor: - # Convert tokens to tensor req, res = episode.request_tensor, episode.response_tensor input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0) @@ -335,14 +354,14 @@ def setup(self): def gsm8k_transform(sample): request: str = sample["question"] - formatted_request = self.tokenizer.apply_chat_template( - [{"role": "user", "content": request}], - tokenize=False, - add_generation_prompt=True, - ) + # formatted_request = self.tokenizer.apply_chat_template( + # [{"role": "user", "content": request}], + # tokenize=False, + # add_generation_prompt=True, + # ) target: str = sample["answer"] formatted_target = target.split("#### ")[1] - return {"request": formatted_request, "target": formatted_target} + return {"request": request, "target": formatted_target} ds = load_dataset( self.path, self.revision, split=self.data_split, streaming=self.streaming @@ -457,10 +476,14 @@ async def continuous_rollouts(): response_len=max_res_tokens, target=target, ) + responses = await policy.generate.choose(prompt) + # print("phil_resp", responses) + for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids episode.response_tokens = response.token_ids + assert len(response.token_ids) <= max_res_tokens episode.ref_logprobs = await ref_model.forward.choose(episode) episode.reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=response.text, target=target @@ -470,6 +493,8 @@ async def continuous_rollouts(): episode.advantage = advantage await replay_buffer.add.choose(episode) + # exit() + rollout_count += 1 if rollout_count % 10 == 0: avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes) @@ -487,6 +512,7 @@ async def continuous_training(): else: training_result = await trainer.train_step.choose(batch) training_step += 1 + exit() if training_step % 10 == 0: print(f"Completed {training_step} training steps") if training_result: From bf31587592e48972f7feabe90f6e39429d2a8eff Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 12:22:56 -0700 Subject: [PATCH 10/31] It runs --- apps/grpo/main.py | 59 ++++++++++++++--------------------------------- 1 file changed, 17 insertions(+), 42 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 52bfeca46..b20649220 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -53,33 +53,23 @@ def __init__(self, epsilon=0.1, beta=0.1): self.beta = beta def forward(self, logprobs, ref_logprobs, advantages, padding_mask): - per_token_kl = ( - torch.exp(ref_logprobs.detach() - logprobs) - - (ref_logprobs.detach() - logprobs) - - 1 - ) + # KL divergence: exp(ref - log) - (ref - log) - 1 
+ logprob_diff = ref_logprobs.detach() - logprobs + per_token_kl = torch.exp(logprob_diff) - logprob_diff - 1 - advantages = advantages[:, None] # [B x G, 1] + # Policy loss: advantages (logprobs - logprobs.detach() cancels to 0, so exp(0) = 1) + per_token_policy_loss = advantages - per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + # Combined loss: -(policy_loss - beta * kl) per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) - loss = (per_token_loss * padding_mask).sum(dim=1) / ( - padding_mask.sum(dim=1) + 1e-8 - ).mean() - - return loss - # log_ratio = ref_logprobs.detach() - logprobs - # kl = torch.exp(log_ratio) - log_ratio - 1 - - # pl = torch.exp(logprobs - logprobs.detach()) * advantages - # loss = -pl + self.beta * kl - - # print(loss.shape, padding_mask.shape) - - # # Compute mean - # loss = (loss * padding_mask).sum() / (padding_mask.sum() + 1e-8) - # return loss + # Masked average + return ( + (per_token_loss * padding_mask) + .sum(dim=1) + .div(padding_mask.sum(dim=1) + 1e-8) + .mean() + ) @dataclass @@ -200,7 +190,7 @@ async def train_step(self, batch: list[Episode]): response = torch.stack(response).to(self.device) # [b x s] ref_logprobs = [e.ref_logprobs for e in batch] - ref_logprobs = torch.stack(ref_logprobs).to(self.device) # [b x s x d] + ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze() # [b x s] advantages = [e.advantage for e in batch] advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] @@ -220,8 +210,8 @@ async def train_step(self, batch: list[Episode]): self.optimizer.zero_grad() loss.backward() - # Gradient clipping (optional but recommended for stability) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + # # Gradient clipping (optional but recommended for stability) + # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.step() @@ -270,7 +260,6 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) total_reward += reward - print("phil_rew_0", total_reward) return total_reward @@ -281,12 +270,9 @@ class ComputeAdvantages(ForgeActor): async def compute(self, group: Group) -> list[float]: # TODO: add batch processing rewards = torch.Tensor([[e.reward for e in group.episodes]]) - print("rewards", rewards) mean = rewards.mean(1, keepdim=True) std = rewards.std(1, keepdim=True) - print(std) - # if std is nan, return 0s. 
Remove this before shipping if std.isnan().any(): advantages = torch.zeros_like(rewards) @@ -294,7 +280,6 @@ async def compute(self, group: Group) -> list[float]: advantages = (rewards - mean) / (std + 1e-4) x = advantages.squeeze(0).tolist() - print("phil_adv_0", x) return x @@ -319,23 +304,15 @@ def __init__(self, model_name, device: torch.device | None = None): @endpoint async def forward(self, episode: Episode) -> torch.Tensor: - # Convert tokens to tensor req, res = episode.request_tensor, episode.response_tensor input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0) mask = input_ids != episode.pad_id - # Compute logits with torch.inference_mode(): logits = self.model(input_ids=input_ids, attention_mask=mask).logits - # Compute logprobs input_ids = input_ids[:, len(req) :] - print("phil_ref_0", input_ids.shape) - print("phil_ref_1", logits.shape) - logprobs = compute_logprobs(logits, input_ids) - print("phil_ref_2", logprobs.shape) - - return logprobs + return compute_logprobs(logits, input_ids) @dataclass @@ -478,7 +455,6 @@ async def continuous_rollouts(): ) responses = await policy.generate.choose(prompt) - # print("phil_resp", responses) for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids @@ -512,7 +488,6 @@ async def continuous_training(): else: training_result = await trainer.train_step.choose(batch) training_step += 1 - exit() if training_step % 10 == 0: print(f"Completed {training_step} training steps") if training_result: From 53c8c897c468c7d69573f57ef1d11ecb7c39cfab Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 12:23:49 -0700 Subject: [PATCH 11/31] Add in ref --- apps/grpo/main.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index b20649220..81f81778b 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -45,7 +45,9 @@ def compute_logprobs( class SimpleGRPOLoss(nn.Module): - """Simplified GRPO Loss for simplified single step updates""" + """Simplified GRPO Loss for simplified single step updates + Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py. + """ def __init__(self, epsilon=0.1, beta=0.1): super().__init__() @@ -53,23 +55,18 @@ def __init__(self, epsilon=0.1, beta=0.1): self.beta = beta def forward(self, logprobs, ref_logprobs, advantages, padding_mask): - # KL divergence: exp(ref - log) - (ref - log) - 1 - logprob_diff = ref_logprobs.detach() - logprobs - per_token_kl = torch.exp(logprob_diff) - logprob_diff - 1 - - # Policy loss: advantages (logprobs - logprobs.detach() cancels to 0, so exp(0) = 1) - per_token_policy_loss = advantages - - # Combined loss: -(policy_loss - beta * kl) - per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) - - # Masked average - return ( - (per_token_loss * padding_mask) - .sum(dim=1) - .div(padding_mask.sum(dim=1) + 1e-8) - .mean() + per_token_kl = ( + torch.exp(ref_logprobs.detach() - logprobs) + - (ref_logprobs.detach() - logprobs) + - 1 ) + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) + loss = ( + (per_token_loss * padding_mask).sum(dim=1) + / (padding_mask.sum(dim=1) + 1e-8) + ).mean() + return loss @dataclass From a13a1ac0661323cebf1a7f95c527085694bf0305 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 12:30:02 -0700 Subject: [PATCH 12/31] Pass linting? 
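The loss restored in PATCH 11 above is the simplified single-step GRPO objective: a k3-style KL penalty against the reference policy plus an advantage-weighted policy term, averaged over non-padding tokens. A self-contained sketch of the same computation (shape assumptions: logprobs, ref_logprobs, and padding_mask are [batch, seq]; advantages is [batch, 1]):

import torch

def grpo_loss_sketch(logprobs, ref_logprobs, advantages, padding_mask, beta: float = 0.1):
    # k3 estimator of KL(pi || pi_ref), per token: exp(r) - r - 1 with r = ref - logp.
    diff = ref_logprobs.detach() - logprobs
    per_token_kl = torch.exp(diff) - diff - 1
    # exp(logp - logp.detach()) equals 1 in value but carries the gradient of logp,
    # so the policy term contributes advantage-weighted gradients for a single update step.
    per_token_policy = torch.exp(logprobs - logprobs.detach()) * advantages
    per_token_loss = -(per_token_policy - beta * per_token_kl)
    # Average over valid (non-pad) tokens per sequence, then over the batch.
    seq_loss = (per_token_loss * padding_mask).sum(dim=1) / (padding_mask.sum(dim=1) + 1e-8)
    return seq_loss.mean()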
--- apps/vllm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/vllm/main.py b/apps/vllm/main.py index d079cbf96..0e438a28a 100644 --- a/apps/vllm/main.py +++ b/apps/vllm/main.py @@ -16,7 +16,7 @@ from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig from forge.controller.service import ServiceConfig, shutdown_service, spawn_service -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.transformers_utils.tokenizer import get_tokenizer From 833a6b6becd4b1cff220cb5af3ae0b01e4599b89 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 12:56:56 -0700 Subject: [PATCH 13/31] Remove extraneous 'calculations' --- apps/grpo/main.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 81f81778b..d2fedae3a 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -174,10 +174,7 @@ def setup(self): @endpoint async def train_step(self, batch: list[Episode]): - total_loss = 0.0 - num_episodes_processed = 0 pad_id = batch[0].pad_id - bsz = len(batch) # prepare batch request = [e.request_tensor for e in batch] @@ -212,10 +209,7 @@ async def train_step(self, batch: list[Episode]): self.optimizer.step() - total_loss += loss.item() - avg_loss = total_loss / bsz - - return {"loss": avg_loss, "episodes_processed": num_episodes_processed} + return {"loss": loss.item()} @endpoint async def update_weights(self, policy_actor): From 0acbe4aa42b6783d38394b2ec4d4f77a5ab86ad2 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 12:58:28 -0700 Subject: [PATCH 14/31] Stub out push weights --- apps/grpo/main.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index d2fedae3a..da1c6a55d 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -212,31 +212,8 @@ async def train_step(self, batch: list[Episode]): return {"loss": loss.item()} @endpoint - async def update_weights(self, policy_actor): - """Update policy model weights with trainer's current weights.""" - # Time how long it takes to update weights - start_time = time.time() - - # Set model to eval mode for weight extraction - self.model.eval() - - # Extract current model state dict - model_state_dict = self.model.state_dict() - - # Convert tensors to CPU for transfer (if they're on GPU) - cpu_state_dict = {} - for key, tensor in model_state_dict.items(): - cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor - - # Update the policy actor's model weights - await policy_actor.update_model_weights.choose(cpu_state_dict) - - # Set model back to training mode - self.model.train() - - # Log the time taken - end_time = time.time() - self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds") + async def push_weights(self): + pass @dataclass From 7d05aad7842c06b636a0945e28f7f48779dbbb87 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 13:16:22 -0700 Subject: [PATCH 15/31] Remove tokenizer, add back in formatting --- apps/grpo/main.py | 13 ++++++------- src/forge/actors/policy.py | 1 - 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index da1c6a55d..43b5985a4 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -6,7 +6,6 @@ import asyncio import logging -import time import uuid from dataclasses import dataclass from typing import Any, Callable, Optional @@ -299,14 +298,14 @@ def 
setup(self): def gsm8k_transform(sample): request: str = sample["question"] - # formatted_request = self.tokenizer.apply_chat_template( - # [{"role": "user", "content": request}], - # tokenize=False, - # add_generation_prompt=True, - # ) + formatted_request = self.tokenizer.apply_chat_template( + [{"role": "user", "content": request}], + tokenize=False, + add_generation_prompt=True, + ) target: str = sample["answer"] formatted_target = target.split("#### ")[1] - return {"request": request, "target": formatted_target} + return {"request": formatted_request, "target": formatted_target} ds = load_dataset( self.path, self.revision, split=self.data_split, streaming=self.streaming diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 9cfb47b46..77a3f0942 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -250,7 +250,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu priority=priority, data_parallel_rank=None, ) - tokenizer = self.processor.input_preprocessor.get_tokenizer_group() # Explicitly keeping the redundant logic to make it easier to pick up # vllm changes From 3c880dd8075a11d280feca57e3622391ac9a82f7 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 13:37:05 -0700 Subject: [PATCH 16/31] Cleanup --- apps/grpo/main.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 43b5985a4..9114c9100 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -329,7 +329,7 @@ async def pad_token(self): async def main(): """Main GRPO training loop with rollout and training processes.""" group_size = 1 - model = "Qwen/Qwen3-1.7B" + model = "Qwen/Qwen3-1.7B-Base" max_req_tokens = 512 max_res_tokens = 128 @@ -436,8 +436,6 @@ async def continuous_rollouts(): episode.advantage = advantage await replay_buffer.add.choose(episode) - # exit() - rollout_count += 1 if rollout_count % 10 == 0: avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes) From 8796fa126bda2eba297c579599bfac5f0beea610 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 4 Sep 2025 14:57:39 -0700 Subject: [PATCH 17/31] Working w/ weight sync --- apps/grpo/main.py | 39 ++++++++++++++++++++++++++++++++------ src/forge/actors/policy.py | 14 ++++++++------ 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 9114c9100..828bd1d76 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -6,6 +6,7 @@ import asyncio import logging +import time import uuid from dataclasses import dataclass from typing import Any, Callable, Optional @@ -21,11 +22,13 @@ from forge.util.metric_logging import get_metric_logger from monarch.actor import endpoint from torch import nn +from torchstore import MultiProcessStore +from torchstore._state_dict_utils import DELIM, push_state_dict from transformers import AutoModelForCausalLM from vllm.transformers_utils.tokenizer import get_tokenizer logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) def compute_logprobs( @@ -145,6 +148,8 @@ class Trainer(ForgeActor): beta: float = 0.1 epsilon: float = 0.1 device: torch.device | None = None + store: MultiProcessStore | None = None + state_dict_key: str = "model_state_dict" @endpoint def setup(self): @@ -211,8 +216,17 @@ async def train_step(self, batch: list[Episode]): return {"loss": loss.item()} @endpoint - async def push_weights(self): - pass + async def push_weights(self, version: int): + """Update 
policy model weights with trainer's current weights.""" + start_time = time.time() + assert self.store is not None, "Store must be provided to save weights" + await push_state_dict( + self.store, + self.model.state_dict(), + f"{self.state_dict_key}{DELIM}{version}", # Use version as key + ) + end_time = time.time() + self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds") @dataclass @@ -340,6 +354,8 @@ async def main(): project="grpo-training", ) + store = await MultiProcessStore.create_store() + # ---- Setup services ---- # ( dataloader, @@ -368,12 +384,14 @@ async def main(): n=group_size, max_tokens=max_res_tokens ), ), + store=store, ), spawn_service( ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1), Trainer, learning_rate=1e-5, model_name=model, + store=store, ), spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), @@ -409,7 +427,7 @@ async def continuous_rollouts(): print("Dataloader is empty, exiting continuous rollout") return prompt, target = sample["request"], sample["target"] - version = 0 # await policy.get_current_version.choose() + version = await policy.get_version.choose() group = Group.new_group( group_id=rollout_count, group_size=group_size, @@ -446,8 +464,11 @@ async def continuous_rollouts(): async def continuous_training(): training_step = 0 + policy_version = 0 while True: - batch = await replay_buffer.sample.choose(curr_policy_version=0) + batch = await replay_buffer.sample.choose( + curr_policy_version=policy_version + ) if batch is None: await asyncio.sleep(0.1) else: @@ -459,7 +480,13 @@ async def continuous_training(): loss_value = training_result.get("loss", 0.0) print(f"Latest loss: {loss_value}") logger.log("loss/training_step", loss_value, training_step) - # await trainer.update_weights(policy) + start_time = time.time() + await trainer.push_weights.choose(policy_version) + print(f"Updating weights took {time.time() - start_time}") + start_time = time.time() + _ = await policy.update_weights.choose() + print(f"Updating policy took {time.time() - start_time}") + policy_version += 1 print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 77a3f0942..33c21615e 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -15,7 +15,7 @@ import torch from monarch.actor import current_rank, endpoint, ProcMesh from torchstore import MultiProcessStore -from torchstore._state_dict_utils import DELIM +from torchstore._state_dict_utils import DELIM, get_state_dict from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.utils import _validate_truncation_size @@ -325,9 +325,8 @@ async def update_weights(self) -> int: futures = [fut for _, fut in self.requests.values()] if futures: await asyncio.gather(*futures) - new_version = self.weights_version + 1 - await self.policy_worker.update.call(version=new_version) - self.weights_version = new_version + await self.policy_worker.update.call(version=self.weights_version) + self.weights_version += 1 return self.weights_version @endpoint @@ -439,9 +438,12 @@ async def update(self, version: int): ) model = self.worker.model_runner.model - current_state_dict = model.state_dict() + new_state_dict = await get_state_dict( + self.torchstore, f"{self.state_dict_key}{DELIM}{version}" + ) + model.load_weights(list(new_state_dict.items())) - await self._load_tensor_parallel_state_dict(current_state_dict, version) + # await 
self._load_tensor_parallel_state_dict(current_state_dict, version) logger.debug("Successfully updated model weights from torchstore") @endpoint From 75447d9fa163856f0bbbd83cae0cc78f6b1571b9 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 5 Sep 2025 13:37:34 -0700 Subject: [PATCH 18/31] stub --- apps/grpo/main.py | 56 +++++++++++------------------ src/forge/actors/policy.py | 52 ++++++++++++++++++--------- src/forge/actors/reference_actor.py | 2 -- 3 files changed, 55 insertions(+), 55 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 828bd1d76..2b93be9af 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -27,9 +27,6 @@ from transformers import AutoModelForCausalLM from vllm.transformers_utils.tokenizer import get_tokenizer -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - def compute_logprobs( logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 @@ -213,7 +210,7 @@ async def train_step(self, batch: list[Episode]): self.optimizer.step() - return {"loss": loss.item()} + return loss.item() @endpoint async def push_weights(self, version: int): @@ -240,6 +237,9 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl total_reward = 0.0 for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) + self.logger.info( + f"Response: {response} | Target: {target} | Reward: {reward}" + ) total_reward += reward return total_reward @@ -253,15 +253,8 @@ async def compute(self, group: Group) -> list[float]: rewards = torch.Tensor([[e.reward for e in group.episodes]]) mean = rewards.mean(1, keepdim=True) std = rewards.std(1, keepdim=True) - - # if std is nan, return 0s. Remove this before shipping - if std.isnan().any(): - advantages = torch.zeros_like(rewards) - else: - advantages = (rewards - mean) / (std + 1e-4) - - x = advantages.squeeze(0).tolist() - return x + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() class RefModel(ForgeActor): @@ -342,8 +335,8 @@ async def pad_token(self): async def main(): """Main GRPO training loop with rollout and training processes.""" - group_size = 1 - model = "Qwen/Qwen3-1.7B-Base" + group_size = 5 + model = "Qwen/Qwen3-4B-Base" max_req_tokens = 512 max_res_tokens = 128 @@ -438,13 +431,15 @@ async def continuous_rollouts(): response_len=max_res_tokens, target=target, ) - responses = await policy.generate.choose(prompt) - + if responses is None: + continue for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids episode.response_tokens = response.token_ids - assert len(response.token_ids) <= max_res_tokens + logger.log( + "response_len/rollout", len(response.token_ids), rollout_count + ) episode.ref_logprobs = await ref_model.forward.choose(episode) episode.reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=response.text, target=target @@ -453,14 +448,12 @@ async def continuous_rollouts(): for episode, advantage in zip(group.episodes, advantages): episode.advantage = advantage await replay_buffer.add.choose(episode) + buffer_size = await replay_buffer._numel.choose() + logger.log("buffer_size/rollout", buffer_size, rollout_count) rollout_count += 1 - if rollout_count % 10 == 0: - avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes) - print( - f"Generated {rollout_count} rollouts w/ average reward {avg_reward}" - ) - logger.log("reward/rollout", avg_reward, rollout_count) + avg_reward = sum(e.reward for e 
in group.episodes) / len(group.episodes) + logger.log("reward/rollout", avg_reward, rollout_count) async def continuous_training(): training_step = 0 @@ -472,21 +465,12 @@ async def continuous_training(): if batch is None: await asyncio.sleep(0.1) else: - training_result = await trainer.train_step.choose(batch) + loss = await trainer.train_step.choose(batch) training_step += 1 - if training_step % 10 == 0: - print(f"Completed {training_step} training steps") - if training_result: - loss_value = training_result.get("loss", 0.0) - print(f"Latest loss: {loss_value}") - logger.log("loss/training_step", loss_value, training_step) - start_time = time.time() + logger.log("loss/training_step", loss, training_step) await trainer.push_weights.choose(policy_version) - print(f"Updating weights took {time.time() - start_time}") - start_time = time.time() - _ = await policy.update_weights.choose() - print(f"Updating policy took {time.time() - start_time}") policy_version += 1 + _ = await policy.update_weights.choose() print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 33c21615e..b9d7dc558 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -13,6 +13,12 @@ from typing import Dict, List import torch + +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh + +from forge.data.sharding import VLLMSharding +from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig from monarch.actor import current_rank, endpoint, ProcMesh from torchstore import MultiProcessStore from torchstore._state_dict_utils import DELIM, get_state_dict @@ -37,15 +43,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh - -from forge.data.sharding import VLLMSharding -from forge.interfaces import Policy as PolicyInterface -from forge.types import ProcessConfig - - -logger = logging.getLogger(__name__) - @dataclass class SamplingOverrides: @@ -279,7 +276,11 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_fut = asyncio.Future() self.requests[request_id] = (parent_req, request_fut) - return await request_fut + try: + generations = await request_fut + return generations + except asyncio.exceptions.CancelledError: + self.logger.debug(f"Request {request_id} was cancelled") # Abstracted to match vllm # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 @@ -321,10 +322,27 @@ async def run(self): @endpoint async def update_weights(self) -> int: """Update the policy weights.""" - # Wait for all current requests to finish, then publish model weights - futures = [fut for _, fut in self.requests.values()] - if futures: - await asyncio.gather(*futures) + # Cancel all current requests and wait for them to finish + pending_futures = [] + for request_id, (parent_req, fut) in list(self.requests.items()): + if not fut.done(): + fut.cancel("Received weight update, cancelling request") + pending_futures.append(fut) + + # Wait for all cancelled requests to finish with a timeout + if pending_futures: + try: + await asyncio.wait_for( + asyncio.gather(*pending_futures, return_exceptions=True), + timeout=5.0, # 5 second timeout + ) + except asyncio.TimeoutError: + logging.warning("Some requests did not cancel within timeout") 
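The cancellation path added above follows a common asyncio pattern: cancel whatever is still in flight, then bound how long you wait for those futures to settle. A standalone sketch of just that pattern, with illustrative names that are not part of the patch:

import asyncio

async def cancel_with_timeout(futures: list[asyncio.Future], timeout: float = 5.0) -> None:
    pending = [f for f in futures if not f.done()]
    for f in pending:
        f.cancel()
    if not pending:
        return
    try:
        # return_exceptions=True keeps gather from re-raising CancelledError here.
        await asyncio.wait_for(
            asyncio.gather(*pending, return_exceptions=True), timeout=timeout
        )
    except asyncio.TimeoutError:
        # Some work ignored cancellation within the budget; the caller decides what's next.
        pass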
+ + # Clear any remaining requests + self.requests.clear() + + # Now update the weights await self.policy_worker.update.call(version=self.weights_version) self.weights_version += 1 return self.weights_version @@ -382,7 +400,7 @@ def __post_init__(self): for key in cfg: value = getattr(self, key) if key != "data_parallel_size" else 1 if getattr(self.vllm_args, key) != value: - logger.warning( + self.logger.warning( f"{key} args don't match value in EngineArgs, overriding with {value}" ) setattr(self.vllm_args, key, value) @@ -433,7 +451,7 @@ async def update(self, version: int): if self.torchstore is None: raise Exception("No torchstore configured, skipping model update") - logger.debug( + self.logger.debug( f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}" ) @@ -444,7 +462,7 @@ async def update(self, version: int): model.load_weights(list(new_state_dict.items())) # await self._load_tensor_parallel_state_dict(current_state_dict, version) - logger.debug("Successfully updated model weights from torchstore") + self.logger.debug("Successfully updated model weights from torchstore") @endpoint async def setup_kv_cache(self): diff --git a/src/forge/actors/reference_actor.py b/src/forge/actors/reference_actor.py index c0b6aad24..120a2228e 100644 --- a/src/forge/actors/reference_actor.py +++ b/src/forge/actors/reference_actor.py @@ -17,8 +17,6 @@ from typing import Any import torch - -from forge.controller import ForgeActor from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf from torch import nn From 31201000606e67a8882da029ab8b61fa008172a0 Mon Sep 17 00:00:00 2001 From: joecummings Date: Mon, 8 Sep 2025 16:17:12 -0700 Subject: [PATCH 19/31] Queue while updating weights --- apps/grpo/main.py | 127 ++++++++----------------------------- src/forge/actors/policy.py | 119 ++++++++++++++++++++++++++-------- 2 files changed, 119 insertions(+), 127 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 6b19c7380..e26a6e9b8 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import asyncio -import logging import time import uuid from dataclasses import dataclass @@ -28,81 +27,6 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -def compute_logprobs( - logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 -) -> torch.Tensor: - context_length = logits.shape[1] - input_ids.shape[1] - - # Truncate request logits and drop last - logits = logits[:, context_length - 1 : -1] - - # Compute logprobs - logprobs = torch.log_softmax(logits / temperature, dim=-1) - logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1) - - return logprobs - - -class SimpleGRPOLoss(nn.Module): - """Simplified GRPO Loss for simplified single step updates - Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py. 
- """ - - def __init__(self, epsilon=0.1, beta=0.1): - super().__init__() - self.epsilon = epsilon - self.beta = beta - - def forward(self, logprobs, ref_logprobs, advantages, padding_mask): - per_token_kl = ( - torch.exp(ref_logprobs.detach() - logprobs) - - (ref_logprobs.detach() - logprobs) - - 1 - ) - per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages - per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) - loss = ( - (per_token_loss * padding_mask).sum(dim=1) - / (padding_mask.sum(dim=1) + 1e-8) - ).mean() - return loss - - -@dataclass -class Episode: - # TODO: add adtional layer for multi-turn - episode_id: str - request: str - policy_version: int - pad_id: int - request_len: int - response_len: int - target: Optional[Any] = None - # processed data - response: Optional[str] = None - request_tokens: Optional[list[int]] = None - response_tokens: Optional[list[int]] = None - ref_logprobs: Optional[torch.Tensor] = None - reward: Optional[float] = None - advantage: Optional[float] = None - - @property - def request_tensor(self): - tensor = torch.tensor(self.request_tokens, dtype=torch.long) - if tensor.shape[0] < self.request_len: # left pad - diff = self.request_len - tensor.shape[0] - tensor = F.pad(tensor, (diff, 0), value=self.pad_id) - return tensor - - @property - def response_tensor(self): - tensor = torch.tensor(self.response_tokens, dtype=torch.long) - if tensor.shape[0] < self.response_len: # right pad - diff = self.response_len - tensor.shape[0] - tensor = F.pad(tensor, (0, diff), value=self.pad_id) - return tensor - - def compute_logprobs( logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 ) -> torch.Tensor: @@ -196,7 +120,7 @@ def new_group( target: Any = None, ): episodes = [] - for i in range(group_size): + for _ in range(group_size): episodes.append( Episode( episode_id=str(uuid.uuid4()), @@ -292,13 +216,12 @@ async def push_weights(self, version: int): """Update policy model weights with trainer's current weights.""" start_time = time.time() assert self.store is not None, "Store must be provided to save weights" - await push_state_dict( - self.store, - self.model.state_dict(), - f"{self.state_dict_key}{DELIM}{version}", # Use version as key - ) + key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id + await push_state_dict(self.store, self.model.state_dict(), key) end_time = time.time() - self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds") + self.logger.debug( + f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" + ) @dataclass @@ -410,10 +333,10 @@ async def pad_token(self): async def main(): """Main GRPO training loop with rollout and training processes.""" - group_size = 4 - model = "Qwen/Qwen3-1.7B-Base" + group_size = 5 + model = "Qwen/Qwen3-4B-Base" max_req_tokens = 512 - max_res_tokens = 128 + max_res_tokens = 512 # ---- Setup WandB Logger ---- # logger = get_metric_logger( @@ -464,8 +387,8 @@ async def main(): spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), ReplayBuffer, - batch_size=4, - max_policy_age=1, + batch_size=8, + max_policy_age=0, # Fully on-policy for now ), spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), @@ -495,6 +418,12 @@ async def continuous_rollouts(): print("Dataloader is empty, exiting continuous rollout") return prompt, target = sample["request"], sample["target"] + responses = await policy.generate.choose(prompt) + # If weights are updating mid-rollout, response will be cancelled and 
service + # will return None. We currently throw away the sample. + if responses is None: + continue + version = await policy.get_version.choose() group = Group.new_group( group_id=rollout_count, @@ -507,15 +436,10 @@ async def continuous_rollouts(): target=target, ) - responses = await policy.generate.choose(prompt) - if responses is None: - continue + # TODO: Parallelize the following calculation for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids episode.response_tokens = response.token_ids - logger.log( - "response_len/rollout", len(response.token_ids), rollout_count - ) episode.ref_logprobs = await ref_model.forward.choose(episode) episode.reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=response.text, target=target @@ -524,12 +448,17 @@ async def continuous_rollouts(): for episode, advantage in zip(group.episodes, advantages): episode.advantage = advantage await replay_buffer.add.choose(episode) - buffer_size = await replay_buffer._numel.choose() - logger.log("buffer_size/rollout", buffer_size, rollout_count) + + avg_response_len = ( + sum(len(e.response_tokens) for e in group.episodes) / group_size + ) + logger.log("avg_response_len/rollout", avg_response_len, rollout_count) + buffer_size = await replay_buffer._numel.choose() + logger.log("buffer_size/rollout", buffer_size, rollout_count) + avg_reward = sum(e.reward for e in group.episodes) / group_size + logger.log("avg_reward/rollout", avg_reward, rollout_count) rollout_count += 1 - avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes) - logger.log("reward/rollout", avg_reward, rollout_count) async def continuous_training(): training_step = 0 @@ -546,7 +475,7 @@ async def continuous_training(): logger.log("loss/training_step", loss, training_step) await trainer.push_weights.choose(policy_version) policy_version += 1 - _ = await policy.update_weights.choose() + await policy.update_weights.choose() print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 5cf6df600..66815051f 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -5,12 +5,11 @@ # LICENSE file in the root directory of this source tree. 
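The continuous_training loop above pairs with PolicyWorker.update through a version-scoped torchstore key: the trainer pushes its state dict under a key built from state_dict_key and the current version, and the policy pulls that same key and hands the tensors to vLLM's load_weights. A condensed sketch of the handshake, assuming a shared MultiProcessStore; collapsing both halves into one function is purely illustrative (in the patches they run in separate actors):

from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, get_state_dict, push_state_dict

async def publish_and_load(
    store: MultiProcessStore,
    trainer_model,
    vllm_model,
    version: int,
    prefix: str = "model_state_dict",
) -> None:
    key = f"{prefix}{DELIM}{version}"
    # Trainer side: publish the current weights under a version-scoped key.
    await push_state_dict(store, trainer_model.state_dict(), key)
    # Policy side: fetch the same key; vLLM's load_weights handles its own
    # parameter mapping (e.g. fused projections).
    new_sd = await get_state_dict(store, key)
    vllm_model.load_weights(list(new_sd.items()))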
import asyncio -import logging import os import sys +import time from copy import copy from dataclasses import asdict, dataclass, field -from typing import Dict, List import torch from monarch.actor import current_rank, endpoint, ProcMesh @@ -21,7 +20,7 @@ from vllm.entrypoints.utils import _validate_truncation_size from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput +from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext @@ -112,6 +111,9 @@ def __post_init__(self): self._policy_proc: ProcMesh | None = None self._worker_procs: ProcMesh | None = None self.weights_version: int = 0 + self._updating_weights: bool = False + self._request_queue: list[tuple[str, int, asyncio.Future]] = [] + self.running: bool = False @classmethod async def launch( # pyright: ignore[reportIncompatibleMethodOverride] @@ -174,7 +176,7 @@ async def setup(self): await self.policy_worker.setup.call(store=self.store) self.request_id = 0 - self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} + self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} self.vllm_args = await self.policy_worker.get_vllm_args.choose() # Setup sampling params @@ -218,12 +220,21 @@ def start_processing(self): self._run_task = asyncio.create_task(self.run()) @endpoint - async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]: + async def generate(self, prompt: str, priority: int = 0) -> RequestOutput | None: + """Generate a response for the given prompt.""" + if self._updating_weights: + request_future = asyncio.Future() + self._request_queue.append((prompt, priority, request_future)) + return await request_future + return await self._generate(prompt, priority) + + async def _generate(self, prompt: str, priority: int = 0) -> RequestOutput | None: + """Internal generation method that doesn't check for weight updates.""" self.request_id += 1 % sys.maxsize request_id = str(self.request_id) # implement from a counter # Wraps prompt into a dict - prompt: Dict[str, str] = convert_input(prompt=prompt) + prompt_dict: dict[str, str] = convert_input(prompt=prompt) # truncate prmpt tokenization_kwargs = self.tokenization_kwargs or {} @@ -238,7 +249,7 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu # process and tokenize prompt prompt_str, request = self.processor.process_inputs( request_id=request_id, - prompt=prompt, + prompt=prompt_dict, params=self.sampling_params, arrival_time=None, lora_request=self.lora_request, @@ -276,11 +287,15 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutpu request_fut = asyncio.Future() self.requests[request_id] = (parent_req, request_fut) + # Yield control to allow the run() loop to process the scheduled request + await asyncio.sleep(0) + try: generations = await request_fut return generations except asyncio.exceptions.CancelledError: self.logger.debug(f"Request {request_id} was cancelled") + return None # Abstracted to match vllm # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 @@ -314,14 +329,27 @@ async def run(self): engine_core_timestamp=outputs.timestamp, iteration_stats=None, ) + for request_output in 
processed_outputs.request_outputs: if request_output.finished: - _, fut = self.requests.pop(request_output.request_id) - fut.set_result(request_output) + if request_output.request_id in self.requests: + _, fut = self.requests.pop(request_output.request_id) + fut.set_result(request_output) @endpoint - async def update_weights(self) -> int: - """Update the policy weights.""" + async def update_weights(self): + """Update the policy weights and schedule processing of queued requests.""" + queued_count = len(self._request_queue) + self.logger.debug( + f"Starting weight update (v{self.weights_version} -> v{self.weights_version + 1})" + ) + if queued_count > 0: + self.logger.debug( + f"Will process {queued_count} queued requests after update" + ) + + self._updating_weights = True + # Cancel all current requests and wait for them to finish pending_futures = [] for request_id, (parent_req, fut) in list(self.requests.items()): @@ -331,24 +359,59 @@ async def update_weights(self) -> int: # Wait for all cancelled requests to finish with a timeout if pending_futures: + self.logger.debug(f"Cancelling {len(pending_futures)} pending requests") try: await asyncio.wait_for( asyncio.gather(*pending_futures, return_exceptions=True), - timeout=5.0, # 5 second timeout + timeout=5.0, ) except asyncio.TimeoutError: - logging.warning("Some requests did not cancel within timeout") + self.logger.warning("Some requests did not cancel within timeout") - # Clear any remaining requests self.requests.clear() - # Now update the weights - await self.policy_worker.update.call(version=self.weights_version) - self.weights_version += 1 - return self.weights_version + try: + await self.policy_worker.update.call(version=self.weights_version) + self.weights_version += 1 + self.logger.info(f"Weight update completed (now v{self.weights_version})") + except Exception as e: + self.logger.error(f"Weight update failed: {e}") + self._updating_weights = False + raise + + self._updating_weights = False + + # Schedule queue processing as a separate task to avoid blocking the endpoint + if self._request_queue: + task = asyncio.create_task(self._process_queued_requests()) + task.add_done_callback(self._queue_processing_callback) + + async def _process_queued_requests(self): + """Process all queued requests after weight update completes.""" + queued_requests = self._request_queue.copy() + self._request_queue.clear() + + for i, (prompt, priority, future) in enumerate(queued_requests): + try: + # Use the internal method directly to avoid the updating weights check + result = await self._generate(prompt, priority) + future.set_result(result) + except Exception as e: + self.logger.error(f"Error processing queued request {i+1}: {e}") + future.set_exception(e) + + def _queue_processing_callback(self, task: asyncio.Task): + """Callback to handle completion/errors of queue processing task.""" + try: + if task.exception(): + self.logger.error(f"Queue processing task failed: {task.exception()}") + else: + self.logger.debug("Queue processing task completed successfully") + except Exception as e: + self.logger.error(f"Error in queue processing callback: {e}") @endpoint - async def _get_model_params(self) -> Dict[str, torch.Tensor]: + async def _get_model_params(self) -> dict[str, torch.Tensor]: """Get the current model parameters. 
Only for testing purposes.""" model_params = await self.policy_worker._get_model_params.choose() return model_params @@ -457,18 +520,18 @@ async def update(self, version: int): if self.torchstore is None: raise Exception("No torchstore configured, skipping model update") - self.logger.debug( - f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}" - ) - + key = f"{self.state_dict_key}{DELIM}{version}" model = self.worker.model_runner.model + start = time.time() new_state_dict = await get_state_dict( self.torchstore, f"{self.state_dict_key}{DELIM}{version}" ) + # We use the load_weights method from vLLM b/c they perform custom mapping like + # fusing QKV linear layers model.load_weights(list(new_state_dict.items())) - - # await self._load_tensor_parallel_state_dict(current_state_dict, version) - self.logger.debug("Successfully updated model weights from torchstore") + self.logger.debug( + f"Loaded state dict from {key} in {time.time() - start} seconds" + ) @endpoint async def setup_kv_cache(self): @@ -505,7 +568,7 @@ async def get_vllm_args(self): return self.vllm_args @endpoint - async def _get_model_params(self) -> Dict[str, torch.Tensor]: + async def _get_model_params(self) -> dict[str, torch.Tensor]: model = self.worker.model_runner.model state_dict = {} @@ -538,7 +601,7 @@ def setup_worker(self): return worker -def convert_input(prompt=None, prompt_token_ids=None) -> Dict: +def convert_input(prompt=None, prompt_token_ids=None) -> dict: assert (prompt is None) ^ (prompt_token_ids is None) if prompt is not None: return {"prompt": prompt} From 8f4bda1460c4a6f09d1f059f9bb639a514487051 Mon Sep 17 00:00:00 2001 From: joecummings Date: Tue, 9 Sep 2025 19:28:44 -0700 Subject: [PATCH 20/31] Cleanup --- apps/grpo/main.py | 89 +++++++++++++---------------- src/forge/actors/policy.py | 112 ++++++++----------------------------- 2 files changed, 60 insertions(+), 141 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index e26a6e9b8..31cbc9be5 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -8,7 +8,7 @@ import time import uuid from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable import torch import torch.nn.functional as F @@ -47,9 +47,8 @@ class SimpleGRPOLoss(nn.Module): Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py. 
""" - def __init__(self, epsilon=0.1, beta=0.1): + def __init__(self, beta: float = 0.1): super().__init__() - self.epsilon = epsilon self.beta = beta def forward(self, logprobs, ref_logprobs, advantages, padding_mask): @@ -76,14 +75,14 @@ class Episode: pad_id: int request_len: int response_len: int - target: Optional[Any] = None + target: Any | None = None # processed data - response: Optional[str] = None - request_tokens: Optional[list[int]] = None - response_tokens: Optional[list[int]] = None - ref_logprobs: Optional[torch.Tensor] = None - reward: Optional[float] = None - advantage: Optional[float] = None + response: str | None = None + request_tokens: list[int] | None = None + response_tokens: list[int] | None = None + ref_logprobs: torch.Tensor | None = None + reward: float | None = None + advantage: float | None = None @property def request_tensor(self): @@ -142,18 +141,15 @@ class Trainer(ForgeActor): model_name: str learning_rate: float = 1e-5 beta: float = 0.1 - epsilon: float = 0.1 device: torch.device | None = None store: MultiProcessStore | None = None state_dict_key: str = "model_state_dict" @endpoint def setup(self): - # Set device if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Initialize model self.model = AutoModelForCausalLM.from_pretrained( self.model_name, dtype=torch.bfloat16, @@ -161,16 +157,14 @@ def setup(self): ).to(self.device) self.model.train() - # Initialize optimizer self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=self.learning_rate ) self.optimizer.zero_grad() - # Initialize loss - self.loss = SimpleGRPOLoss(self.epsilon, self.beta) + self.loss = SimpleGRPOLoss(self.beta) - self.logger.info(f"Model initialized on {self.device}") + self.logger.info(f"Trainer model initialized on {self.device}") @endpoint async def train_step(self, batch: list[Episode]): @@ -190,26 +184,19 @@ async def train_step(self, batch: list[Episode]): advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] del batch - # compute policy logprobs input_ids = torch.cat([request, response], dim=1) mask = input_ids != pad_id logits = self.model(input_ids=input_ids, attention_mask=mask).logits logprobs = compute_logprobs(logits, response) del logits - # compute loss mask = response != pad_id loss = self.loss(logprobs, ref_logprobs, advantages, mask) - self.optimizer.zero_grad() loss.backward() - - # # Gradient clipping (optional but recommended for stability) - # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - self.optimizer.step() - return loss.item() + return loss.detach() @endpoint async def push_weights(self, version: int): @@ -232,14 +219,11 @@ class RewardActor(ForgeActor): @endpoint async def evaluate_response(self, prompt: str, response: str, target: str) -> float: - total_reward = 0.0 + total_rewards = 0.0 for reward_fn in self.reward_functions: reward = reward_fn(prompt, response, target) - self.logger.info( - f"Response: {response} | Target: {target} | Reward: {reward}" - ) - total_reward += reward - return total_reward + total_rewards += reward + return total_rewards / len(self.reward_functions) class ComputeAdvantages(ForgeActor): @@ -248,7 +232,7 @@ class ComputeAdvantages(ForgeActor): @endpoint async def compute(self, group: Group) -> list[float]: # TODO: add batch processing - rewards = torch.Tensor([[e.reward for e in group.episodes]]) + rewards = torch.tensor([[e.reward for e in group.episodes]]) mean = rewards.mean(1, keepdim=True) std = rewards.std(1, 
keepdim=True) advantages = (rewards - mean) / (std + 1e-4) @@ -302,9 +286,17 @@ def setup(self): self.tokenizer = get_tokenizer(self.model) def gsm8k_transform(sample): + system_prompt = """ + Put all your scratchpad work between and tags. + Your final answer should be between and tags otherwise it will not be scored. + """ request: str = sample["question"] + as_chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] formatted_request = self.tokenizer.apply_chat_template( - [{"role": "user", "content": request}], + as_chat, tokenize=False, add_generation_prompt=True, ) @@ -333,21 +325,20 @@ async def pad_token(self): async def main(): """Main GRPO training loop with rollout and training processes.""" - group_size = 5 - model = "Qwen/Qwen3-4B-Base" + model = "Qwen/Qwen3-1.7B" max_req_tokens = 512 max_res_tokens = 512 - - # ---- Setup WandB Logger ---- # - logger = get_metric_logger( + group_size = 8 + batch_size = 16 + max_policy_age = 0 # Fully on-policy + mlogger = get_metric_logger( "wandb", freq=1, project="grpo-training", ) - store = await MultiProcessStore.create_store() - # ---- Setup services ---- # + store = await MultiProcessStore.create_store() ( dataloader, policy, @@ -387,8 +378,8 @@ async def main(): spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), ReplayBuffer, - batch_size=8, - max_policy_age=0, # Fully on-policy for now + batch_size=batch_size, + max_policy_age=max_policy_age, ), spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), @@ -419,11 +410,6 @@ async def continuous_rollouts(): return prompt, target = sample["request"], sample["target"] responses = await policy.generate.choose(prompt) - # If weights are updating mid-rollout, response will be cancelled and service - # will return None. We currently throw away the sample. 
- if responses is None: - continue - version = await policy.get_version.choose() group = Group.new_group( group_id=rollout_count, @@ -440,6 +426,7 @@ async def continuous_rollouts(): for episode, response in zip(group.episodes, responses.outputs): episode.request_tokens = responses.prompt_token_ids episode.response_tokens = response.token_ids + episode.response = response.text episode.ref_logprobs = await ref_model.forward.choose(episode) episode.reward = await reward_actor.evaluate_response.choose( prompt=prompt, response=response.text, target=target @@ -452,11 +439,11 @@ async def continuous_rollouts(): avg_response_len = ( sum(len(e.response_tokens) for e in group.episodes) / group_size ) - logger.log("avg_response_len/rollout", avg_response_len, rollout_count) + mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count) buffer_size = await replay_buffer._numel.choose() - logger.log("buffer_size/rollout", buffer_size, rollout_count) + mlogger.log("buffer_size/rollout", buffer_size, rollout_count) avg_reward = sum(e.reward for e in group.episodes) / group_size - logger.log("avg_reward/rollout", avg_reward, rollout_count) + mlogger.log("avg_reward/rollout", avg_reward, rollout_count) rollout_count += 1 @@ -472,7 +459,7 @@ async def continuous_training(): else: loss = await trainer.train_step.choose(batch) training_step += 1 - logger.log("loss/training_step", loss, training_step) + mlogger.log("loss/training_step", loss, training_step) await trainer.push_weights.choose(policy_version) policy_version += 1 await policy.update_weights.choose() diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 66815051f..b1966caed 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -111,8 +111,6 @@ def __post_init__(self): self._policy_proc: ProcMesh | None = None self._worker_procs: ProcMesh | None = None self.weights_version: int = 0 - self._updating_weights: bool = False - self._request_queue: list[tuple[str, int, asyncio.Future]] = [] self.running: bool = False @classmethod @@ -220,16 +218,19 @@ def start_processing(self): self._run_task = asyncio.create_task(self.run()) @endpoint - async def generate(self, prompt: str, priority: int = 0) -> RequestOutput | None: - """Generate a response for the given prompt.""" - if self._updating_weights: - request_future = asyncio.Future() - self._request_queue.append((prompt, priority, request_future)) - return await request_future + async def generate(self, prompt: str, priority: int = 0) -> RequestOutput: + """Generate a response for the given prompt + + Args: + prompt (str): The prompt to generate a response for. + priority (int, optional): The priority of the request. Defaults to 0. + + Returns: + RequestOutput: vLLM class with the generated response. 
+ """ return await self._generate(prompt, priority) - async def _generate(self, prompt: str, priority: int = 0) -> RequestOutput | None: - """Internal generation method that doesn't check for weight updates.""" + async def _generate(self, prompt: str, priority: int = 0) -> RequestOutput: self.request_id += 1 % sys.maxsize request_id = str(self.request_id) # implement from a counter @@ -287,15 +288,7 @@ async def _generate(self, prompt: str, priority: int = 0) -> RequestOutput | Non request_fut = asyncio.Future() self.requests[request_id] = (parent_req, request_fut) - # Yield control to allow the run() loop to process the scheduled request - await asyncio.sleep(0) - - try: - generations = await request_fut - return generations - except asyncio.exceptions.CancelledError: - self.logger.debug(f"Request {request_id} was cancelled") - return None + return await request_fut # Abstracted to match vllm # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 @@ -338,77 +331,16 @@ async def run(self): @endpoint async def update_weights(self): - """Update the policy weights and schedule processing of queued requests.""" - queued_count = len(self._request_queue) - self.logger.debug( - f"Starting weight update (v{self.weights_version} -> v{self.weights_version + 1})" - ) - if queued_count > 0: - self.logger.debug( - f"Will process {queued_count} queued requests after update" - ) - - self._updating_weights = True - - # Cancel all current requests and wait for them to finish - pending_futures = [] - for request_id, (parent_req, fut) in list(self.requests.items()): - if not fut.done(): - fut.cancel("Received weight update, cancelling request") - pending_futures.append(fut) - - # Wait for all cancelled requests to finish with a timeout - if pending_futures: - self.logger.debug(f"Cancelling {len(pending_futures)} pending requests") - try: - await asyncio.wait_for( - asyncio.gather(*pending_futures, return_exceptions=True), - timeout=5.0, - ) - except asyncio.TimeoutError: - self.logger.warning("Some requests did not cancel within timeout") - - self.requests.clear() - - try: - await self.policy_worker.update.call(version=self.weights_version) - self.weights_version += 1 - self.logger.info(f"Weight update completed (now v{self.weights_version})") - except Exception as e: - self.logger.error(f"Weight update failed: {e}") - self._updating_weights = False - raise - - self._updating_weights = False - - # Schedule queue processing as a separate task to avoid blocking the endpoint - if self._request_queue: - task = asyncio.create_task(self._process_queued_requests()) - task.add_done_callback(self._queue_processing_callback) - - async def _process_queued_requests(self): - """Process all queued requests after weight update completes.""" - queued_requests = self._request_queue.copy() - self._request_queue.clear() - - for i, (prompt, priority, future) in enumerate(queued_requests): - try: - # Use the internal method directly to avoid the updating weights check - result = await self._generate(prompt, priority) - future.set_result(result) - except Exception as e: - self.logger.error(f"Error processing queued request {i+1}: {e}") - future.set_exception(e) - - def _queue_processing_callback(self, task: asyncio.Task): - """Callback to handle completion/errors of queue processing task.""" - try: - if task.exception(): - self.logger.error(f"Queue processing task failed: {task.exception()}") - else: - self.logger.debug("Queue processing task completed successfully") - 
except Exception as e: - self.logger.error(f"Error in queue processing callback: {e}") + # TODO: If generating long sequences, this might be long and will block policy weight updates + curr_requests = [fut for _, fut in self.requests.values()] + if curr_requests: + self.logger.debug(f"Waiting for {len(curr_requests)} pending requests") + await asyncio.gather(*curr_requests) + + self.logger.debug(f"Starting weight update on {self.__class__.__name__}") + await self.policy_worker.update.call(version=self.weights_version) + self.weights_version += 1 + self.logger.info(f"Weight update completed (now v{self.weights_version})") @endpoint async def _get_model_params(self) -> dict[str, torch.Tensor]: From 7825255d19c53a317af06da9bc389b095010d2ba Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 08:04:45 -0700 Subject: [PATCH 21/31] Make sd conversion happen on push --- apps/grpo/main.py | 45 +++++++++++++++++++++++++++++++++++--- src/forge/actors/policy.py | 16 +++----------- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 31cbc9be5..a7632207e 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -166,6 +166,44 @@ def setup(self): self.logger.info(f"Trainer model initialized on {self.device}") + def _qwen3_hf_to_vllm(self, saved_sd): + """Convert transformers state dict to vLLM format.""" + load_sd = {} + num_layers = 28 # For Qwen3-1.7B + + # Copy over directly mapped keys + for k in saved_sd: + if any( + x in k + for x in [ + "down_proj", + "input_layernorm", + "post_attention_layernorm", + "o_proj", + "norm.weight", + "embed_tokens.weight", + "lm_head.weight", + ] + ): + load_sd[k] = saved_sd[k] + + # Fuse QKV and gate_up_proj + for i in range(num_layers): + prefix = f"model.layers.{i}." 
+ + # QKV fusion + q = saved_sd[prefix + "self_attn.q_proj.weight"] + k = saved_sd[prefix + "self_attn.k_proj.weight"] + v = saved_sd[prefix + "self_attn.v_proj.weight"] + load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0) + + # MLP gate_up_proj fusion + gate = saved_sd[prefix + "mlp.gate_proj.weight"] + up = saved_sd[prefix + "mlp.up_proj.weight"] + load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0) + + return load_sd + @endpoint async def train_step(self, batch: list[Episode]): pad_id = batch[0].pad_id @@ -204,7 +242,8 @@ async def push_weights(self, version: int): start_time = time.time() assert self.store is not None, "Store must be provided to save weights" key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id - await push_state_dict(self.store, self.model.state_dict(), key) + new_sd = self._qwen3_hf_to_vllm(self.model.state_dict()) + await push_state_dict(self.store, new_sd, key) end_time = time.time() self.logger.debug( f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" @@ -460,9 +499,9 @@ async def continuous_training(): loss = await trainer.train_step.choose(batch) training_step += 1 mlogger.log("loss/training_step", loss, training_step) - await trainer.push_weights.choose(policy_version) + await trainer.push_weights.call(policy_version) policy_version += 1 - await policy.update_weights.choose() + await policy.update_weights.call() print("Starting GRPO training loops...") # TODO: Start multiple rollouts once all serivces support it diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index b1966caed..c186f0849 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -14,7 +14,7 @@ import torch from monarch.actor import current_rank, endpoint, ProcMesh from torchstore import MultiProcessStore -from torchstore._state_dict_utils import DELIM, get_state_dict +from torchstore._state_dict_utils import DELIM from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.utils import _validate_truncation_size @@ -425,9 +425,6 @@ async def _load_tensor_parallel_state_dict( """ Load full state dict from torchstore into tensor parallel model with deterministic sharding. 
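
The QKV/gate_up fusion introduced in this patch can be sanity-checked in isolation. The sketch below is illustrative only: it uses tiny made-up dimensions and a toy state dict rather than the real Qwen3-1.7B shapes or `model.state_dict()`, but applies the same concatenation rule as `_qwen3_hf_to_vllm` (concatenate along the output/row dimension).

    import torch

    # Toy per-layer HF-style weights with made-up sizes.
    hidden, q_out, kv_out, ffn = 8, 8, 4, 16
    prefix = "model.layers.0."
    saved_sd = {
        prefix + "self_attn.q_proj.weight": torch.randn(q_out, hidden),
        prefix + "self_attn.k_proj.weight": torch.randn(kv_out, hidden),
        prefix + "self_attn.v_proj.weight": torch.randn(kv_out, hidden),
        prefix + "mlp.gate_proj.weight": torch.randn(ffn, hidden),
        prefix + "mlp.up_proj.weight": torch.randn(ffn, hidden),
    }

    # Same fusion rule as the patch: stack q, k, v (and gate, up) row-wise.
    qkv = torch.cat(
        [saved_sd[prefix + f"self_attn.{p}_proj.weight"] for p in ("q", "k", "v")],
        dim=0,
    )
    gate_up = torch.cat(
        [saved_sd[prefix + "mlp.gate_proj.weight"],
         saved_sd[prefix + "mlp.up_proj.weight"]],
        dim=0,
    )
    assert qkv.shape == (q_out + 2 * kv_out, hidden)   # (16, 8)
    assert gate_up.shape == (2 * ffn, hidden)          # (32, 8)
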
""" - - updated_count = 0 - # setting explictly to llama3 for now as its our only use case sharding = VLLMSharding(self.tensor_parallel_size, self.rank) for param_name in current_state_dict.keys(): @@ -444,23 +441,16 @@ async def _load_tensor_parallel_state_dict( current_tensor, ) - updated_count += 1 - @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" if self.torchstore is None: raise Exception("No torchstore configured, skipping model update") - key = f"{self.state_dict_key}{DELIM}{version}" model = self.worker.model_runner.model + current_state_dict = model.state_dict() start = time.time() - new_state_dict = await get_state_dict( - self.torchstore, f"{self.state_dict_key}{DELIM}{version}" - ) - # We use the load_weights method from vLLM b/c they perform custom mapping like - # fusing QKV linear layers - model.load_weights(list(new_state_dict.items())) + await self._load_tensor_parallel_state_dict(current_state_dict, version) self.logger.debug( f"Loaded state dict from {key} in {time.time() - start} seconds" ) From b511fe343486a934fa0d1ad7ac3c81d8fc8b628a Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 08:17:50 -0700 Subject: [PATCH 22/31] Sum over train_step valuemesh --- apps/grpo/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index a7632207e..1aa04b955 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -496,7 +496,7 @@ async def continuous_training(): if batch is None: await asyncio.sleep(0.1) else: - loss = await trainer.train_step.choose(batch) + loss = sum(await trainer.train_step.call(batch)) training_step += 1 mlogger.log("loss/training_step", loss, training_step) await trainer.push_weights.call(policy_version) From 1a6d6dfe7aa9059da6bd15c290e9fc070d4042d2 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 09:41:18 -0700 Subject: [PATCH 23/31] Update config --- apps/grpo/main.py | 2 +- apps/grpo/qwen3_1_7b.yaml | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 7afc08c0c..aef3cac13 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -366,7 +366,7 @@ async def sample(self) -> dict[str, str] | None: @endpoint async def pad_token(self): - return self.tokenizer.pad_token_id + return self._tokenizer.pad_token_id async def main(cfg: DictConfig): diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 8ba96a096..48d163f64 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -1,11 +1,11 @@ # GRPO Training Configuration # Global configuration -group_size: 4 -batch_size: 4 +group_size: 8 +batch_size: 16 max_req_tokens: 512 -max_res_tokens: 128 -model: "Qwen/Qwen3-1.7B-Base" +max_res_tokens: 512 +model: "Qwen/Qwen3-1.7B" # Dataset configuration dataset: @@ -13,6 +13,7 @@ dataset: revision: "main" data_split: "train" streaming: true + tokenizer: ${model} service: procs_per_replica: 1 num_replicas: 1 @@ -24,10 +25,10 @@ policy: model: ${model} tensor_parallel_size: 1 pipeline_parallel_size: 1 - enforce_eager: true + enforce_eager: false sampling_config: - n: 4 - max_tokens: 128 + n: ${group_size} + max_tokens: ${max_res_tokens} temperature: 1.0 top_p: 1.0 service: @@ -37,6 +38,7 @@ policy: # Trainer configuration trainer: + model_name: ${model} learning_rate: 1e-5 service: procs_per_replica: 1 From e31f8152cdee99103d0e4485e8eff1bec188caf9 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 
12:36:54 -0700 Subject: [PATCH 24/31] Loss updates --- apps/grpo/main.py | 39 ++++++++++++++++++-------------------- src/forge/actors/policy.py | 13 ++++++------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index aef3cac13..f8b25faf2 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -49,7 +49,8 @@ def compute_logprobs( class SimpleGRPOLoss(nn.Module): """Simplified GRPO Loss for simplified single step updates - Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py. + Inspired by the Hugging Face TRL implementation: + https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624. """ def __init__(self, beta: float = 0.1): @@ -57,16 +58,12 @@ def __init__(self, beta: float = 0.1): self.beta = beta def forward(self, logprobs, ref_logprobs, advantages, padding_mask): - per_token_kl = ( - torch.exp(ref_logprobs.detach() - logprobs) - - (ref_logprobs.detach() - logprobs) - - 1 - ) + kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages - per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl) + per_token_loss = -(per_token_policy_loss - self.beta * kl) loss = ( - (per_token_loss * padding_mask).sum(dim=1) - / (padding_mask.sum(dim=1) + 1e-8) + ((per_token_loss * padding_mask).sum(dim=1)) + / (padding_mask.sum(dim=1).clamp(min=1.0)) ).mean() return loss @@ -211,21 +208,21 @@ def _qwen3_hf_to_vllm(self, saved_sd): return load_sd @endpoint - async def train_step(self, batch: list[Episode]): - batch = batch[self.dp_rank] - pad_id = batch[0].pad_id + async def train_step(self, batch: list[list[Episode]]): + microbatch = batch[self.dp_rank] + pad_id = microbatch[0].pad_id # prepare batch - request = [e.request_tensor for e in batch] + request = [e.request_tensor for e in microbatch] request = torch.stack(request).to(self.device) # [b x s] - response = [e.response_tensor for e in batch] + response = [e.response_tensor for e in microbatch] response = torch.stack(response).to(self.device) # [b x s] - ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = [e.ref_logprobs for e in microbatch] ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze() # [b x s] - advantages = [e.advantage for e in batch] + advantages = [e.advantage for e in microbatch] advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1] del batch @@ -522,10 +519,10 @@ async def continuous_training(): ) -@parse -def recipe_main(cfg: DictConfig) -> None: - asyncio.run(main(cfg)) +if __name__ == "__main__": + @parse + def _main(cfg): + asyncio.run(main(cfg)) -if __name__ == "__main__": - recipe_main() + _main() # @parse grabs the cfg from CLI diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 38f418aac..01062f609 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -11,15 +11,8 @@ from collections.abc import Mapping from copy import copy from dataclasses import asdict, dataclass, field, fields -from typing import Dict, List import torch - -from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh - -from forge.data.sharding import VLLMSharding -from forge.interfaces import Policy as PolicyInterface -from forge.types import ProcessConfig from monarch.actor import current_rank, endpoint, ProcMesh from torchstore import MultiProcessStore from torchstore._state_dict_utils import DELIM @@ 
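
As a quick check on the rewritten `SimpleGRPOLoss`, here is a minimal standalone computation of the same masked mean with dummy tensors. All values and shapes are made up for illustration; it does not use the Trainer class, only the formula shown in the patch.

    import torch

    beta = 0.1
    # Dummy batch: 2 sequences of length 3; the last token of sequence 2 is padding.
    logprobs = torch.tensor([[-1.0, -0.5, -0.2], [-0.8, -0.3, -9.0]], requires_grad=True)
    ref_logprobs = torch.tensor([[-1.1, -0.4, -0.3], [-0.9, -0.2, -9.0]])
    advantages = torch.tensor([[1.0], [-0.5]])                 # [b x 1], broadcast over tokens
    padding_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])

    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
    per_token_loss = -(per_token_policy_loss - beta * kl)
    loss = (
        (per_token_loss * padding_mask).sum(dim=1)
        / padding_mask.sum(dim=1).clamp(min=1.0)
    ).mean()
    loss.backward()  # padded positions are masked out and contribute nothing
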
-44,6 +37,12 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase +from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh + +from forge.data.sharding import VLLMSharding +from forge.interfaces import Policy as PolicyInterface +from forge.types import ProcessConfig + @dataclass class SamplingConfig: From 55c32be78b24e430c79d9c8ed444979847257c7b Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 12:47:22 -0700 Subject: [PATCH 25/31] Updated rewards (just played around a bit) --- src/forge/data/rewards.py | 99 ++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 47 deletions(-) diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 644c69d1b..2982536bd 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -17,62 +17,67 @@ def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1): self.tolerance = tolerance self.partial_credit = partial_credit - def _to_float(self, text) -> Optional[float]: - """Safely parse a string into a float, or return None if invalid.""" - if text is None: - return None - try: - return float(str(text).strip()) - except (ValueError, TypeError): - return None - - def _extract_number(self, text: str) -> Optional[float]: - """Try to extract a numeric answer from text.""" - number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)" - patterns = [ - r"####\s*" + number_pattern, - r"(?:the\s+)?answer\s+is\s*" + number_pattern, - r"(?:answer:|result:)\s*" + number_pattern, - r"\$" + number_pattern, # currency - number_pattern, # fallback - r"=\s*" + number_pattern + r"\s*(?:\.|$)", - r"\b" + number_pattern + r"\s*(?:\.|$)", - ] - text = text.lower().strip() - for pattern in patterns: - matches = re.findall(pattern, text) - if matches: - return self._to_float(matches[-1]) - return None - def __call__(self, prompt: str, response: str, target: str) -> float: """Compute math correctness reward.""" - # Parse expected - expected_answer = self._to_float(target) + target_number = self._to_float(target) + if target_number is None: + return 0.0 - # Parse response - model_answer = self._extract_number(response) + # Look for answer in tags + answer_match = re.search(r"(.*?)", response, re.DOTALL) - # Scoring - if expected_answer is None or model_answer is None: - return self.partial_credit # Partial credit for attempting + if answer_match: + model_answer = self._to_float(answer_match.group(1).strip()) + if ( + model_answer is not None + and abs(target_number - model_answer) < self.tolerance + ): + return 1.0 # Correct answer - if abs(expected_answer - model_answer) < self.tolerance: - return 1.0 # Correct answer - return 0.0 # Incorrect answer + # Check for partial credit: target number appears elsewhere in response + response_without_answer_tags = re.sub( + r".*?", "", response, flags=re.DOTALL + ) + # Convert to int if it's a whole number to avoid "117.0" vs "117" mismatch + target_str = ( + str(int(target_number)) + if target_number.is_integer() + else str(target_number) + ) + if target_str in response_without_answer_tags: + return self.partial_credit + + return 0.0 # No match + + def _to_float(self, text: str) -> float | None: + """Convert text to float, return None if invalid.""" + try: + # Remove common non-numeric characters like $, commas, etc. 
+ cleaned_text = re.sub(r"[$,\s]", "", text.strip()) + return float(cleaned_text) + except (ValueError, AttributeError): + return None class ThinkingReward(Reward): """Reward class for evaluating use of tags in reasoning.""" - def __init__(self, reward_value: float = 0.5): - self.reward_value = reward_value + def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0): + self.partial_reward = partial_reward + self.full_reward = full_reward + self._THINK_BLOCK_RE = re.compile( + r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL + ) + self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE) - def __call__( - self, prompt: str, response: str, target: Optional[str] = None - ) -> float: - """Check if response contains ... tags.""" - resp = response.lower() - if "" in resp and "" in resp: - return self.reward_value + def __call__(self, prompt: str, response: str, target: str | None = None) -> float: + matches = self._THINK_BLOCK_RE.findall(response or "") + has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches) + has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response or "")) or bool( + matches + ) + if has_well_formed: + return self.full_reward + elif has_attempt: + return self.partial_reward return 0.0 From b74a47c5e10a6dfeb19428b19b0b0f264deab446 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 13:01:16 -0700 Subject: [PATCH 26/31] Update rewards --- apps/grpo/qwen3_1_7b.yaml | 2 +- src/forge/data/rewards.py | 11 +- tests/unit_tests/rl/test_math_reward.py | 224 ++++++++++---------- tests/unit_tests/rl/test_thinking_reward.py | 190 +++++++++++------ 4 files changed, 242 insertions(+), 185 deletions(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 48d163f64..d6e8cec11 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -2,7 +2,7 @@ # Global configuration group_size: 8 -batch_size: 16 +batch_size: 8 max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-1.7B" diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py index 2982536bd..29a86fc3a 100644 --- a/src/forge/data/rewards.py +++ b/src/forge/data/rewards.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import re -from typing import Optional from forge.interfaces import Reward @@ -71,11 +70,13 @@ def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0): self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE) def __call__(self, prompt: str, response: str, target: str | None = None) -> float: - matches = self._THINK_BLOCK_RE.findall(response or "") + """Compute thinking reward.""" + if not response: + return 0.0 + + matches = self._THINK_BLOCK_RE.findall(response) has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches) - has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response or "")) or bool( - matches - ) + has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response)) or bool(matches) if has_well_formed: return self.full_reward elif has_attempt: diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/rl/test_math_reward.py index 2f3521b4d..7e31a694f 100644 --- a/tests/unit_tests/rl/test_math_reward.py +++ b/tests/unit_tests/rl/test_math_reward.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. 
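
A small usage sketch of the updated rewards, based only on the behaviour visible above, with the default `partial_credit=0.1` and `partial_reward=0.2` / `full_reward=1.0`. It assumes the answer is wrapped in <answer>…</answer> tags and the reasoning in <think>…</think> tags, as the surrounding regexes and test names indicate.

    from forge.data.rewards import MathReward, ThinkingReward

    math_reward = MathReward()
    think_reward = ThinkingReward()

    # Correct value inside <answer> tags -> full reward.
    assert math_reward("p", "<think>6*7</think><answer>42</answer>", "42") == 1.0
    # Target only appears outside the answer tags -> partial credit.
    assert math_reward("p", "It is 42, so <answer>43</answer>", "42") == 0.1
    # Neither -> no reward.
    assert math_reward("p", "no idea", "42") == 0.0

    # Non-empty <think> block -> full reward; bare/empty tag -> partial; nothing -> 0.
    assert think_reward("p", "<think>reasoning</think> done") == 1.0
    assert think_reward("p", "<think></think>") == 0.2
    assert think_reward("p", "just an answer") == 0.0
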
import unittest -from unittest import mock from forge.data.rewards import MathReward @@ -36,6 +35,13 @@ def test_to_float_valid_numbers(self): self.assertEqual(self.reward._to_float("0"), 0.0) self.assertEqual(self.reward._to_float(" 123.45 "), 123.45) + def test_to_float_with_currency_and_formatting(self): + """Test _to_float with currency symbols and commas.""" + self.assertEqual(self.reward._to_float("$42"), 42.0) + self.assertEqual(self.reward._to_float("$1,000"), 1000.0) + self.assertEqual(self.reward._to_float("1,234.56"), 1234.56) + self.assertEqual(self.reward._to_float("$ 42.50 "), 42.5) + def test_to_float_invalid_inputs(self): """Test _to_float with invalid inputs.""" self.assertIsNone(self.reward._to_float("abc")) @@ -48,154 +54,140 @@ def test_to_float_edge_cases(self): """Test _to_float with edge cases.""" self.assertEqual(self.reward._to_float("1e6"), 1000000.0) self.assertEqual(self.reward._to_float("-1.5e-3"), -0.0015) - self.assertEqual(self.reward._to_float("inf"), float("inf")) - self.assertEqual(self.reward._to_float("-inf"), float("-inf")) - - def test_extract_number_gsm8k_format(self): - """Test _extract_number with GSM8K style format.""" - self.assertEqual(self.reward._extract_number("#### 42"), 42.0) - self.assertEqual(self.reward._extract_number("#### -3.14"), -3.14) - self.assertEqual(self.reward._extract_number("Some text #### 123.45"), 123.45) - - def test_extract_number_answer_patterns(self): - """Test _extract_number with various answer patterns.""" - self.assertEqual(self.reward._extract_number("The answer is 42"), 42.0) - self.assertEqual(self.reward._extract_number("answer is 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("Answer: 123"), 123.0) - self.assertEqual(self.reward._extract_number("Result: -5.5"), -5.5) - - def test_extract_number_equals_pattern(self): - """Test _extract_number with equals sign patterns.""" - self.assertEqual(self.reward._extract_number("x = 42."), 42.0) - self.assertEqual(self.reward._extract_number("The result = 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("calculation = -7.5."), -7.5) - - def test_extract_number_end_of_text(self): - """Test _extract_number with numbers at end of text.""" - self.assertEqual(self.reward._extract_number("The final result is 42."), 42.0) - self.assertEqual(self.reward._extract_number("We get 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("Answer: -5.5."), -5.5) - - def test_extract_number_fallback_pattern(self): - """Test _extract_number with fallback pattern (any number).""" - self.assertEqual(self.reward._extract_number("There are 42 items"), 42.0) - self.assertEqual(self.reward._extract_number("Cost is $3.14 per item"), 3.14) - self.assertEqual(self.reward._extract_number("Temperature: -5.5 degrees"), -5.5) - - def test_extract_number_multiple_matches(self): - """Test _extract_number returns the last match when multiple numbers exist.""" - # Should return the last match from the pattern - self.assertEqual( - self.reward._extract_number("First 10, then 20, finally 30"), 30.0 - ) - self.assertEqual( - self.reward._extract_number("#### 5 but actually #### 10"), 10.0 - ) - def test_extract_number_no_match(self): - """Test _extract_number when no numbers are found.""" - self.assertIsNone(self.reward._extract_number("No numbers here")) - self.assertIsNone(self.reward._extract_number("")) - self.assertIsNone(self.reward._extract_number("Just text")) + def test_call_correct_answer_in_tags(self): + """Test __call__ with correct answers in tags.""" + 
self.assertEqual(self.reward("prompt", "42", "42"), 1.0) + self.assertEqual(self.reward("prompt", "3.14", "3.14"), 1.0) + self.assertEqual(self.reward("prompt", "-5.5", "-5.5"), 1.0) - def test_extract_number_case_insensitive(self): - """Test _extract_number is case insensitive.""" - self.assertEqual(self.reward._extract_number("THE ANSWER IS 42"), 42.0) - self.assertEqual(self.reward._extract_number("Answer: 3.14"), 3.14) - self.assertEqual(self.reward._extract_number("RESULT: 123"), 123.0) + def test_call_answer_tags_with_whitespace(self): + """Test __call__ with answer tags containing whitespace.""" + self.assertEqual(self.reward("prompt", " 42 ", "42"), 1.0) + self.assertEqual( + self.reward("prompt", "\n3.14\n", "3.14"), 1.0 + ) - def test_call_correct_answer(self): - """Test __call__ with correct answers.""" - self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 1.0) - self.assertEqual(self.reward("prompt", "#### 3.14", "3.14"), 1.0) - self.assertEqual(self.reward("prompt", "Result: -5.5", "-5.5"), 1.0) + def test_call_answer_tags_with_complex_content(self): + """Test __call__ with complex content in answer tags.""" + response = """ + Let me solve this step by step: + First, I calculate 2 + 3 = 5 + Then, I multiply by 4: 5 * 4 = 20 + Finally, I subtract 8: 20 - 8 = 12 + 12 + """ + self.assertEqual(self.reward("prompt", response, "12"), 1.0) def test_call_within_tolerance(self): """Test __call__ with answers within tolerance.""" # Default tolerance is 1e-6 - self.assertEqual(self.reward("prompt", "42.0000001", "42"), 1.0) - self.assertEqual(self.reward("prompt", "3.1400001", "3.14"), 1.0) - - # Custom tolerance - self.assertEqual(self.custom_reward("prompt", "42.0001", "42"), 1.0) - self.assertEqual(self.custom_reward("prompt", "3.141", "3.14"), 1.0) - - def test_call_outside_tolerance(self): - """Test __call__ with answers outside tolerance.""" - self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) - self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) - self.assertEqual(self.custom_reward("prompt", "42.01", "42"), 0.0) - - def test_call_invalid_target(self): - """Test __call__ with invalid target values.""" self.assertEqual( - self.reward("prompt", "42", "invalid"), self.reward.partial_credit + self.reward("prompt", "42.0000001", "42"), 1.0 ) - self.assertEqual(self.reward("prompt", "42", ""), self.reward.partial_credit) self.assertEqual( - self.reward("prompt", "42", "not a number"), self.reward.partial_credit + self.reward("prompt", "3.1400001", "3.14"), 1.0 ) - def test_call_invalid_response(self): - """Test __call__ with invalid response values.""" + # Custom tolerance self.assertEqual( - self.reward("prompt", "no number", "42"), self.reward.partial_credit + self.custom_reward("prompt", "42.0001", "42"), 1.0 ) - self.assertEqual(self.reward("prompt", "", "42"), self.reward.partial_credit) self.assertEqual( - self.reward("prompt", "just text", "42"), self.reward.partial_credit + self.custom_reward("prompt", "3.141", "3.14"), 1.0 + ) + + def test_call_outside_tolerance(self): + """Test __call__ with answers outside tolerance.""" + self.assertEqual(self.reward("prompt", "42.1", "42"), 0.0) + self.assertEqual(self.reward("prompt", "3.15", "3.14"), 0.0) + self.assertEqual( + self.custom_reward("prompt", "42.01", "42"), 0.0 ) - def test_call_both_invalid(self): - """Test __call__ with both invalid target and response.""" + def test_call_partial_credit_target_in_response(self): + """Test __call__ with partial credit when target appears in response.""" + 
response = "The calculation shows 42 but I put 43" + self.assertEqual(self.reward("prompt", response, "42"), 0.1) + + response = "Let me work through this: 42 + 1 = 43. 43" + self.assertEqual(self.reward("prompt", response, "42"), 0.1) + + def test_call_partial_credit_custom_value(self): + """Test __call__ with custom partial credit value.""" + response = "The calculation shows 42 but I put 43" + self.assertEqual(self.custom_reward("prompt", response, "42"), 0.2) + + def test_call_no_partial_credit_with_answer_tags(self): + """Test __call__ doesn't give partial credit if target is only in answer tags.""" + response = "Let me solve this. 42" + # Target 100 is not elsewhere in response, so no partial credit + self.assertEqual(self.reward("prompt", response, "100"), 0.0) + + def test_call_integer_target_formatting(self): + """Test __call__ with integer targets formatted correctly.""" + # Integer targets should be formatted without decimal point + response = "I calculated and got 117 as the answer. 118" + self.assertEqual(self.reward("prompt", response, "117"), 0.1) + + # Should work with 117.0 in target too + self.assertEqual(self.reward("prompt", response, "117.0"), 0.1) + + def test_call_float_target_formatting(self): + """Test __call__ with float targets.""" + response = "I calculated and got 3.14 as the answer. 3.15" + self.assertEqual(self.reward("prompt", response, "3.14"), 0.1) + + def test_call_invalid_target(self): + """Test __call__ with invalid target values.""" + self.assertEqual(self.reward("prompt", "42", "invalid"), 0.0) + self.assertEqual(self.reward("prompt", "42", ""), 0.0) self.assertEqual( - self.reward("prompt", "no number", "invalid"), self.reward.partial_credit + self.reward("prompt", "42", "not a number"), 0.0 ) - self.assertEqual(self.reward("prompt", "", ""), self.reward.partial_credit) - def test_call_custom_partial_credit(self): - """Test __call__ uses custom partial credit value.""" - self.assertEqual(self.custom_reward("prompt", "no number", "42"), 0.2) - self.assertEqual(self.custom_reward("prompt", "42", "invalid"), 0.2) + def test_call_no_answer_tags(self): + """Test __call__ with response that has no answer tags.""" + # Should still check for partial credit + self.assertEqual(self.reward("prompt", "The answer is 42", "42"), 0.1) + self.assertEqual(self.reward("prompt", "No matching number", "42"), 0.0) + + def test_call_invalid_answer_in_tags(self): + """Test __call__ with invalid answer in tags.""" + response = "not a number but 42 is correct" + self.assertEqual(self.reward("prompt", response, "42"), 0.1) def test_call_zero_values(self): """Test __call__ with zero values.""" - self.assertEqual(self.reward("prompt", "0", "0"), 1.0) - self.assertEqual(self.reward("prompt", "The answer is 0", "0.0"), 1.0) + self.assertEqual(self.reward("prompt", "0", "0"), 1.0) + self.assertEqual(self.reward("prompt", "0.0", "0"), 1.0) def test_call_negative_values(self): """Test __call__ with negative values.""" - self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) - self.assertEqual(self.reward("prompt", "#### -3.14", "-3.14"), 1.0) - self.assertEqual(self.reward("prompt", "-5", "-4.9"), 0.0) + self.assertEqual(self.reward("prompt", "-42", "-42"), 1.0) + self.assertEqual(self.reward("prompt", "-3.14", "-3.14"), 1.0) def test_call_large_numbers(self): """Test __call__ with large numbers.""" - self.assertEqual(self.reward("prompt", "1000000", "1000000"), 1.0) - self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) - self.assertEqual(self.reward("prompt", 
"1000001", "1000000"), 0.0) + self.assertEqual( + self.reward("prompt", "1000000", "1000000"), 1.0 + ) + self.assertEqual(self.reward("prompt", "1e6", "1000000"), 1.0) def test_call_small_numbers(self): """Test __call__ with very small numbers.""" - self.assertEqual(self.reward("prompt", "0.000001", "0.000001"), 1.0) - self.assertEqual(self.reward("prompt", "1e-6", "0.000001"), 1.0) - - def test_call_complex_response_text(self): - """Test __call__ with complex response text containing multiple elements.""" - response = """ - Let me solve this step by step: - First, I calculate 2 + 3 = 5 - Then, I multiply by 4: 5 * 4 = 20 - Finally, I subtract 8: 20 - 8 = 12 - #### 12 - """ - self.assertEqual(self.reward("prompt", response, "12"), 1.0) + self.assertEqual( + self.reward("prompt", "0.000001", "0.000001"), 1.0 + ) + self.assertEqual( + self.reward("prompt", "1e-6", "0.000001"), 1.0 + ) - def test_call_with_units_and_formatting(self): - """Test __call__ with responses containing units and formatting.""" - self.assertEqual(self.reward("prompt", "The cost is $42.50", "42.5"), 1.0) - self.assertEqual(self.reward("prompt", "Distance: 3.14 meters", "3.14"), 1.0) - self.assertEqual(self.reward("prompt", "Temperature is -5.5°C", "-5.5"), 1.0) + def test_call_multiple_answer_tags(self): + """Test __call__ with multiple answer tags (should use first one).""" + response = "First answer: 42 Second: 43" + self.assertEqual(self.reward("prompt", response, "42"), 1.0) + self.assertEqual(self.reward("prompt", response, "43"), 0.1) if __name__ == "__main__": diff --git a/tests/unit_tests/rl/test_thinking_reward.py b/tests/unit_tests/rl/test_thinking_reward.py index 592ceb896..b95823e9a 100644 --- a/tests/unit_tests/rl/test_thinking_reward.py +++ b/tests/unit_tests/rl/test_thinking_reward.py @@ -13,29 +13,36 @@ class TestThinkingReward(unittest.TestCase): def setUp(self): """Set up test fixtures before each test method.""" self.reward = ThinkingReward() - self.custom_reward = ThinkingReward(reward_value=0.8) + self.custom_reward = ThinkingReward(partial_reward=0.3, full_reward=0.9) def test_init_default_values(self): """Test ThinkingReward initialization with default values.""" reward = ThinkingReward() - self.assertEqual(reward.reward_value, 0.5) + self.assertEqual(reward.partial_reward, 0.2) + self.assertEqual(reward.full_reward, 1.0) def test_init_custom_values(self): """Test ThinkingReward initialization with custom values.""" - reward = ThinkingReward(reward_value=0.8) - self.assertEqual(reward.reward_value, 0.8) + reward = ThinkingReward(partial_reward=0.3, full_reward=0.9) + self.assertEqual(reward.partial_reward, 0.3) + self.assertEqual(reward.full_reward, 0.9) - def test_call_with_both_tags(self): - """Test __call__ with response containing both and tags.""" - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + def test_regex_patterns(self): + """Test that regex patterns are compiled correctly.""" + reward = ThinkingReward() + self.assertIsNotNone(reward._THINK_BLOCK_RE) + self.assertIsNotNone(reward._THINK_TAG_ATTEMPT_RE) + + def test_call_with_well_formed_thinking_block(self): + """Test __call__ with well-formed thinking blocks.""" + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) - result = self.custom_reward("prompt", response) - self.assertEqual(result, 0.8) + result = self.custom_reward("prompt", "This is my reasoning") + self.assertEqual(result, 0.9) - def 
test_call_with_both_tags_complex_content(self): - """Test __call__ with complex content between thinking tags.""" + def test_call_with_well_formed_thinking_block_complex_content(self): + """Test __call__ with complex content in thinking blocks.""" response = """ Let me solve this problem step by step. @@ -47,40 +54,58 @@ def test_call_with_both_tags_complex_content(self): The answer is 4. """ result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + self.assertEqual(result, 1.0) + + def test_call_with_minimal_content_thinking_block(self): + """Test __call__ with minimal content that still counts as well-formed.""" + result = self.reward("prompt", "x") + self.assertEqual(result, 1.0) + + def test_call_with_empty_thinking_block(self): + """Test __call__ with empty thinking block.""" + result = self.reward("prompt", "") + self.assertEqual(result, 0.2) # Should give partial reward, not full + + def test_call_with_whitespace_only_thinking_block(self): + """Test __call__ with whitespace-only thinking block.""" + result = self.reward("prompt", " \n \t ") + self.assertEqual(result, 0.2) # Should give partial reward, not full def test_call_with_only_opening_tag(self): - """Test __call__ with response containing only tag.""" - response = "This is incomplete reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.0) + """Test __call__ with response containing only opening tag.""" + result = self.reward("prompt", "This is incomplete reasoning") + self.assertEqual(result, 0.2) # Should give partial reward for attempt def test_call_with_only_closing_tag(self): - """Test __call__ with response containing only tag.""" - response = "This is incomplete reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.0) + """Test __call__ with response containing only closing tag.""" + result = self.reward("prompt", "This is incomplete reasoning") + self.assertEqual(result, 0.2) # Should give partial reward for attempt def test_call_with_no_tags(self): """Test __call__ with response containing no thinking tags.""" - response = "This is just a regular response without any thinking tags." - result = self.reward("prompt", response) + result = self.reward( + "prompt", "This is just a regular response without any thinking tags." 
+ ) self.assertEqual(result, 0.0) def test_call_case_insensitive(self): """Test __call__ is case insensitive for thinking tags.""" - # Mixed case tags should work - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) - response = "This is my reasoning" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + result = self.reward("prompt", "This is my reasoning") + self.assertEqual(result, 1.0) + + def test_call_with_whitespace_in_tags(self): + """Test __call__ with whitespace in thinking tags.""" + result = self.reward("prompt", "< think >This is my reasoning") + self.assertEqual(result, 1.0) + + result = self.reward("prompt", "<\tthink\n>Content") + self.assertEqual(result, 1.0) def test_call_multiple_thinking_blocks(self): """Test __call__ with multiple thinking blocks.""" @@ -90,54 +115,93 @@ def test_call_multiple_thinking_blocks(self): Second thought """ result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + self.assertEqual(result, 1.0) def test_call_nested_tags(self): """Test __call__ with nested or malformed tags.""" - # Nested tags - should still work as long as both tags exist - response = "Outer inner thought" + result = self.reward( + "prompt", "Outer inner thought" + ) + self.assertEqual(result, 1.0) + + def test_call_multiline_thinking_block(self): + """Test __call__ with multiline thinking blocks.""" + response = """ + This is a multiline + thinking block with + lots of content + """ result = self.reward("prompt", response) - self.assertEqual(result, 0.5) - - def test_call_empty_thinking_block(self): - """Test __call__ with empty thinking block.""" - response = "" - result = self.reward("prompt", response) - self.assertEqual(result, 0.5) + self.assertEqual(result, 1.0) def test_call_empty_response(self): """Test __call__ with empty response.""" result = self.reward("prompt", "") self.assertEqual(result, 0.0) - def test_call_tags_with_extra_whitespace(self): - """Test __call__ with thinking tags containing extra whitespace.""" - response = "< think >This has spaces< /think >" - result = self.reward("prompt", response) - self.assertEqual(result, 0.0) # Should not match due to spaces in tags + def test_call_none_response(self): + """Test __call__ with None response.""" + result = self.reward("prompt", None) + self.assertEqual(result, 0.0) def test_call_with_target_parameter(self): """Test __call__ with target parameter (should be ignored).""" - response = "This is my reasoning" - result = self.reward("prompt", response, target="some target") - self.assertEqual(result, 0.5) + result = self.reward( + "prompt", "This is my reasoning", target="some target" + ) + self.assertEqual(result, 1.0) result = self.reward("prompt", "no tags", target="some target") self.assertEqual(result, 0.0) - def test_call_zero_reward_value(self): - """Test __call__ with zero reward value.""" - zero_reward = ThinkingReward(reward_value=0.0) - response = "This is my reasoning" - result = zero_reward("prompt", response) + result = self.reward( + "prompt", "This is my reasoning", target=None + ) + self.assertEqual(result, 1.0) + + def test_call_custom_reward_values(self): + """Test __call__ with custom reward values.""" + 
response_full = "This is proper reasoning" + response_partial = "" + response_none = "no thinking tags" + + # Test custom partial reward + self.assertEqual(self.custom_reward("prompt", response_full), 0.9) + self.assertEqual(self.custom_reward("prompt", response_partial), 0.3) + self.assertEqual(self.custom_reward("prompt", response_none), 0.0) + + def test_call_zero_custom_values(self): + """Test __call__ with zero custom values.""" + zero_reward = ThinkingReward(partial_reward=0.0, full_reward=0.0) + result = zero_reward("prompt", "This is my reasoning") self.assertEqual(result, 0.0) - def test_call_negative_reward_value(self): - """Test __call__ with negative reward value.""" - negative_reward = ThinkingReward(reward_value=-0.5) - response = "This is my reasoning" - result = negative_reward("prompt", response) - self.assertEqual(result, -0.5) + def test_call_negative_reward_values(self): + """Test __call__ with negative reward values.""" + negative_reward = ThinkingReward(partial_reward=-0.1, full_reward=-0.5) + + self.assertEqual( + negative_reward("prompt", "This is proper reasoning"), -0.5 + ) + self.assertEqual(negative_reward("prompt", ""), -0.1) + + def test_call_edge_case_characters(self): + """Test __call__ with edge case characters in thinking blocks.""" + result = self.reward( + "prompt", "Special chars: @#$%^&*()_+-=[]{}|;':\",./<>?`~" + ) + self.assertEqual(result, 1.0) + + def test_call_unicode_characters(self): + """Test __call__ with unicode characters in thinking blocks.""" + result = self.reward("prompt", "Unicode: αβγδε 中文 🚀") + self.assertEqual(result, 1.0) + + def test_call_very_long_thinking_block(self): + """Test __call__ with very long thinking blocks.""" + long_content = "A" * 10000 + result = self.reward("prompt", f"{long_content}") + self.assertEqual(result, 1.0) if __name__ == "__main__": From 14d635436474611590a255312266a3a9b184c09c Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 13:10:11 -0700 Subject: [PATCH 27/31] Fix last math reward test --- tests/unit_tests/rl/test_math_reward.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/rl/test_math_reward.py b/tests/unit_tests/rl/test_math_reward.py index 7e31a694f..726b1173c 100644 --- a/tests/unit_tests/rl/test_math_reward.py +++ b/tests/unit_tests/rl/test_math_reward.py @@ -187,7 +187,13 @@ def test_call_multiple_answer_tags(self): """Test __call__ with multiple answer tags (should use first one).""" response = "First answer: 42 Second: 43" self.assertEqual(self.reward("prompt", response, "42"), 1.0) - self.assertEqual(self.reward("prompt", response, "43"), 0.1) + self.assertEqual(self.reward("prompt", response, "43"), 0.0) + + # Test case where target appears outside answer tags for partial credit + response_with_partial = ( + "I think the answer is 43. 42 But 43 might be better." 
+ ) + self.assertEqual(self.reward("prompt", response_with_partial, "43"), 0.1) if __name__ == "__main__": From 8fa44510f8d6b628a5fca75ece64336478463812 Mon Sep 17 00:00:00 2001 From: joecummings Date: Thu, 11 Sep 2025 13:58:05 -0700 Subject: [PATCH 28/31] Async by 1 --- apps/grpo/qwen3_1_7b.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index d6e8cec11..6ffc5f551 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -2,7 +2,7 @@ # Global configuration group_size: 8 -batch_size: 8 +batch_size: 16 max_req_tokens: 512 max_res_tokens: 512 model: "Qwen/Qwen3-1.7B" @@ -48,7 +48,7 @@ trainer: # Replay buffer configuration replay_buffer: batch_size: ${batch_size} - max_policy_age: 0 + max_policy_age: 1 # Async by 1 dp_size: 1 service: procs_per_replica: 1 From bdd03a834cac68bc1d27df9fd09e58136ba6bab9 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 12 Sep 2025 09:38:19 -0700 Subject: [PATCH 29/31] Seg fault --- apps/grpo/main.py | 63 +++++++------------------------------ apps/grpo/qwen3_1_7b.yaml | 2 +- src/forge/actors/policy.py | 24 ++++++-------- src/forge/actors/trainer.py | 46 +++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 66 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index f8b25faf2..6fb607dde 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -14,9 +14,11 @@ import torch import torch.nn.functional as F +import torchstore as ts from datasets import load_dataset from forge.actors.policy import Policy from forge.actors.replay_buffer import ReplayBuffer +from forge.actors.trainer import _qwen3_hf_to_vllm from forge.cli.config import parse from forge.controller.actor import ForgeActor from forge.controller.service import ServiceConfig, shutdown_service, spawn_service @@ -26,8 +28,7 @@ from omegaconf import DictConfig from src.forge.data.utils import exclude_service from torch import nn -from torchstore import MultiProcessStore -from torchstore._state_dict_utils import DELIM, push_state_dict +from torchstore.state_dict_utils import DELIM, put_state_dict from transformers import AutoModelForCausalLM from vllm.transformers_utils.tokenizer import get_tokenizer @@ -144,12 +145,11 @@ class Trainer(ForgeActor): learning_rate: float = 1e-5 beta: float = 0.1 device: torch.device | None = None - store: MultiProcessStore | None = None state_dict_key: str = "model_state_dict" dp_rank: int = 0 # TODO: support data parallelism, hard code it for now @endpoint - def setup(self): + async def setup(self): if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -167,45 +167,9 @@ def setup(self): self.loss = SimpleGRPOLoss(self.beta) - self.logger.info(f"Trainer model initialized on {self.device}") + self.store = await ts.initialize() - def _qwen3_hf_to_vllm(self, saved_sd): - """Convert transformers state dict to vLLM format.""" - load_sd = {} - num_layers = 28 # For Qwen3-1.7B - - # Copy over directly mapped keys - for k in saved_sd: - if any( - x in k - for x in [ - "down_proj", - "input_layernorm", - "post_attention_layernorm", - "o_proj", - "norm.weight", - "embed_tokens.weight", - "lm_head.weight", - ] - ): - load_sd[k] = saved_sd[k] - - # Fuse QKV and gate_up_proj - for i in range(num_layers): - prefix = f"model.layers.{i}." 
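
The `max_policy_age: 1  # Async by 1` setting above allows rollouts generated with the previous policy version to still be trained on, giving the rollout and training loops one version of slack. The check below is only a plausible illustration of that gating rule, not the actual ReplayBuffer implementation.

    def is_fresh(episode_policy_version: int, current_version: int,
                 max_policy_age: int = 1) -> bool:
        """Keep an episode only if it is at most `max_policy_age` versions behind."""
        return current_version - episode_policy_version <= max_policy_age

    assert is_fresh(episode_policy_version=4, current_version=5)      # async by 1: kept
    assert not is_fresh(episode_policy_version=3, current_version=5)  # too stale: dropped
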
- - # QKV fusion - q = saved_sd[prefix + "self_attn.q_proj.weight"] - k = saved_sd[prefix + "self_attn.k_proj.weight"] - v = saved_sd[prefix + "self_attn.v_proj.weight"] - load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0) - - # MLP gate_up_proj fusion - gate = saved_sd[prefix + "mlp.gate_proj.weight"] - up = saved_sd[prefix + "mlp.up_proj.weight"] - load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0) - - return load_sd + self.logger.info(f"Trainer model initialized on {self.device}") @endpoint async def train_step(self, batch: list[list[Episode]]): @@ -238,16 +202,16 @@ async def train_step(self, batch: list[list[Episode]]): loss.backward() self.optimizer.step() - return loss.detach() + return loss.item() @endpoint async def push_weights(self, version: int): """Update policy model weights with trainer's current weights.""" start_time = time.time() - assert self.store is not None, "Store must be provided to save weights" + assert self.store is not None, "Store must be initialized to save weights" key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id - new_sd = self._qwen3_hf_to_vllm(self.model.state_dict()) - await push_state_dict(self.store, new_sd, key) + new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28) + await put_state_dict(self.store, new_sd, key) end_time = time.time() self.logger.debug( f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" @@ -322,11 +286,11 @@ class DatasetActor(ForgeActor): revision: str = "main" data_split: str = "train" streaming: bool = True - tokenizer: str = "Qwen/Qwen3-1.7B" + model: str = "Qwen/Qwen3-1.7B" @endpoint def setup(self): - self._tokenizer = get_tokenizer(self.tokenizer) + self._tokenizer = get_tokenizer(self.model) def gsm8k_transform(sample): system_prompt = """ @@ -380,7 +344,6 @@ async def main(cfg: DictConfig): ) # ---- Setup services ---- # - store = await MultiProcessStore.create_store() ( dataloader, policy, @@ -399,13 +362,11 @@ async def main(cfg: DictConfig): ServiceConfig(**cfg.policy.service), Policy, **exclude_service(cfg.policy), - store=store, ), spawn_service( ServiceConfig(**cfg.trainer.service), Trainer, **exclude_service(cfg.trainer), - store=store, ), spawn_service( ServiceConfig(**cfg.replay_buffer.service), diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 6ffc5f551..6fc60bf53 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -13,7 +13,7 @@ dataset: revision: "main" data_split: "train" streaming: true - tokenizer: ${model} + model: ${model} service: procs_per_replica: 1 num_replicas: 1 diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 01062f609..f16850e3b 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -13,9 +13,9 @@ from dataclasses import asdict, dataclass, field, fields import torch +import torchstore as ts from monarch.actor import current_rank, endpoint, ProcMesh -from torchstore import MultiProcessStore -from torchstore._state_dict_utils import DELIM +from torchstore.state_dict_utils import DELIM from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.utils import _validate_truncation_size @@ -107,14 +107,13 @@ class Policy(PolicyInterface): lora_request: LoRARequest | None = None tokenization_kwargs: dict = field(default_factory=dict) policy_worker: "PolicyWorker" = None - store: MultiProcessStore | None = None def __post_init__(self): self._run_task: asyncio.Task | None = None self._policy_proc: ProcMesh | None 
= None self._worker_procs: ProcMesh | None = None self.weights_version: int = 0 - self.running: bool = False + self.running = False if isinstance(self.engine_config, Mapping): self.engine_config = EngineConfig.from_dict(self.engine_config) if isinstance(self.sampling_config, Mapping): @@ -128,7 +127,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] engine_config: EngineConfig | Mapping = EngineConfig(), sampling_config: SamplingConfig | Mapping = SamplingConfig(), available_devices: str | None = None, - store: MultiProcessStore | None = None, **kwargs, ) -> "Policy": # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES @@ -161,7 +159,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] sampling_config=sampling_config, available_devices=available_devices, policy_worker=workers, - store=store, ) policy._policy_proc = policy_proc policy._worker_procs = worker_procs @@ -189,7 +186,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] async def setup(self): # Set up policy_worker assert self.policy_worker is not None, "Policy worker should not be None" - await self.policy_worker.setup.call(store=self.store) + await self.policy_worker.setup.call() self.request_id = 0 self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} @@ -343,9 +340,8 @@ async def run(self): for request_output in processed_outputs.request_outputs: if request_output.finished: - if request_output.request_id in self.requests: - _, fut = self.requests.pop(request_output.request_id) - fut.set_result(request_output) + _, fut = self.requests.pop(request_output.request_id) + fut.set_result(request_output) @endpoint async def update_weights(self): @@ -403,8 +399,8 @@ def __post_init__(self): self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS) @endpoint - async def setup(self, store: MultiProcessStore = None): - self.torchstore = store + async def setup(self): + self.store = await ts.initialize() # TODO: remove ["gpus"] when monarch implements a flat rank self.rank = current_rank()["gpus"] self.worker = self.setup_worker() @@ -428,7 +424,7 @@ async def _load_tensor_parallel_state_dict( # Load the full tensor from torchstore # TODO: only get the part of the tensor that is needed - stored_tensor = await self.torchstore.get( + stored_tensor = await self.store.get( f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}" ) sharding.load_from_source_to_target( @@ -440,7 +436,7 @@ async def _load_tensor_parallel_state_dict( @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" - if self.torchstore is None: + if self.store is None: raise Exception("No torchstore configured, skipping model update") key = f"{self.state_dict_key}{DELIM}{version}" model = self.worker.model_runner.model diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 4232ca5ca..062fabe8a 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -268,3 +268,49 @@ def push_weights(self) -> None: async def cleanup(self) -> None: if self.engine.checkpointer: self.engine.checkpointer.close() + + +def _qwen3_hf_to_vllm( + sd: dict[str, torch.Tensor], num_layers: int +) -> dict[str, torch.Tensor]: + """Convert transformers state dict to vLLM format. Specifically, this fuses + QKV projection and MLP gate_up_proj layers. + + Args: + sd (dict): State dict from HF model. + num_layers (int): Number of layers in the model. 
+ + Returns: + dict: State dict in vLLM format. + """ + load_sd = {} + + # Copy over directly mapped keys + for k in sd: + if any( + x in k + for x in [ + "down_proj", + "input_layernorm", + "post_attention_layernorm", + "o_proj", + "norm.weight", + "embed_tokens.weight", + "lm_head.weight", + ] + ): + load_sd[k] = sd[k] + + for i in range(num_layers): + prefix = f"model.layers.{i}." + # QKV fusion + q = sd[prefix + "self_attn.q_proj.weight"] + k = sd[prefix + "self_attn.k_proj.weight"] + v = sd[prefix + "self_attn.v_proj.weight"] + load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0) + # MLP gate_up_proj fusion + gate = sd[prefix + "mlp.gate_proj.weight"] + up = sd[prefix + "mlp.up_proj.weight"] + load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0) + + return load_sd From 7eedc91d5f7130d3043d31de8b677904533d5e83 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 12 Sep 2025 10:23:02 -0700 Subject: [PATCH 30/31] Make torchstore actually work! --- apps/grpo/main.py | 10 ++++------ src/forge/actors/policy.py | 5 +---- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 6fb607dde..e7ff6b6cb 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from src.forge.data.utils import exclude_service from torch import nn -from torchstore.state_dict_utils import DELIM, put_state_dict +from torchstore.state_dict_utils import DELIM from transformers import AutoModelForCausalLM from vllm.transformers_utils.tokenizer import get_tokenizer @@ -167,8 +167,6 @@ async def setup(self): self.loss = SimpleGRPOLoss(self.beta) - self.store = await ts.initialize() - self.logger.info(f"Trainer model initialized on {self.device}") @endpoint @@ -207,11 +205,10 @@ async def train_step(self, batch: list[list[Episode]]): @endpoint async def push_weights(self, version: int): """Update policy model weights with trainer's current weights.""" - start_time = time.time() - assert self.store is not None, "Store must be initialized to save weights" key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28) - await put_state_dict(self.store, new_sd, key) + start_time = time.time() + await ts.put_state_dict(new_sd, key) end_time = time.time() self.logger.debug( f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" @@ -344,6 +341,7 @@ async def main(cfg: DictConfig): ) # ---- Setup services ---- # + await ts.initialize() ( dataloader, policy, diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index f16850e3b..788ccfe3f 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -400,7 +400,6 @@ def __post_init__(self): @endpoint async def setup(self): - self.store = await ts.initialize() # TODO: remove ["gpus"] when monarch implements a flat rank self.rank = current_rank()["gpus"] self.worker = self.setup_worker() @@ -424,7 +423,7 @@ async def _load_tensor_parallel_state_dict( # Load the full tensor from torchstore # TODO: only get the part of the tensor that is needed - stored_tensor = await self.store.get( + stored_tensor = await ts.get( f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}" ) sharding.load_from_source_to_target( @@ -436,8 +435,6 @@ async def _load_tensor_parallel_state_dict( @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" - if self.store is None: - raise 
Exception("No torchstore configured, skipping model update") key = f"{self.state_dict_key}{DELIM}{version}" model = self.worker.model_runner.model current_state_dict = model.state_dict() From 4044087f5595afeada9e8414433f9cee119798d8 Mon Sep 17 00:00:00 2001 From: joecummings Date: Fri, 12 Sep 2025 13:54:52 -0700 Subject: [PATCH 31/31] Last updates --- apps/grpo/main.py | 4 ++-- src/forge/actors/policy.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index e7ff6b6cb..15234f928 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -196,9 +196,9 @@ async def train_step(self, batch: list[list[Episode]]): mask = response != pad_id loss = self.loss(logprobs, ref_logprobs, advantages, mask) - self.optimizer.zero_grad() loss.backward() self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) return loss.item() @@ -447,7 +447,7 @@ async def continuous_training(): if batch is None: await asyncio.sleep(0.1) else: - loss = sum(await trainer.train_step.call(batch)) + loss = await trainer.train_step.choose(batch) training_step += 1 mlogger.log("loss/training_step", loss, training_step) await trainer.push_weights.call(policy_version) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 788ccfe3f..fd9f11482 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -243,9 +243,6 @@ async def generate(self, prompt: str, priority: int = 0) -> RequestOutput: Returns: RequestOutput: vLLM class with the generated response. """ - return await self._generate(prompt, priority) - - async def _generate(self, prompt: str, priority: int = 0) -> RequestOutput: self.request_id += 1 % sys.maxsize request_id = str(self.request_id) # implement from a counter
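
Putting the weight-sync pieces together: the trainer publishes the converted state dict under a versioned key, and each policy worker pulls individual parameters back out under that key. The sketch below only shows the key convention used in these patches (the `model_state_dict` prefix and `DELIM` separator); the `put_state_dict`/`get` calls are left as comments because they require an initialized torchstore, and the parameter name is just an example.

    from torchstore.state_dict_utils import DELIM

    state_dict_key = "model_state_dict"
    version = 3
    param_name = "model.layers.0.self_attn.qkv_proj.weight"

    # Trainer side: one versioned key per push_weights call.
    push_key = f"{state_dict_key}{DELIM}{version}"
    # await ts.put_state_dict(new_sd, push_key)

    # Policy worker side: fetch a single parameter from that version during update.
    fetch_key = f"{state_dict_key}{DELIM}{version}{DELIM}{param_name}"
    # stored_tensor = await ts.get(fetch_key)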