Merged
36 commits
e6b7692
first changes
pbontrager Aug 29, 2025
a95a001
core updates
pbontrager Aug 31, 2025
3ba0df6
batch update
pbontrager Sep 1, 2025
3e32264
fix typo
pbontrager Sep 2, 2025
e4723bb
Merge branch 'main' into ungroup
pbontrager Sep 2, 2025
5a17c8b
Merge branch 'main' into ungroup
pbontrager Sep 2, 2025
52028a5
missing import
pbontrager Sep 2, 2025
e2a3a68
debug merge
pbontrager Sep 2, 2025
2cf9d00
more fixes
pbontrager Sep 4, 2025
b85320c
Remove dtype warnings
joecummings Sep 4, 2025
f7626ce
Stub
joecummings Sep 4, 2025
bf31587
It runs
joecummings Sep 4, 2025
53c8c89
Add in ref
joecummings Sep 4, 2025
f494949
<Replace this line with a title. Use 1 line only, 67 chars or less>
joecummings Sep 4, 2025
a13a1ac
Pass linting?
joecummings Sep 4, 2025
833a6b6
Remove extraneous 'calculations'
joecummings Sep 4, 2025
0acbe4a
Stub out push weights
joecummings Sep 4, 2025
7d05aad
Remove tokenizer, add back in formatting
joecummings Sep 4, 2025
3c880dd
Cleanup
joecummings Sep 4, 2025
8796fa1
Working w/ weight sync
joecummings Sep 4, 2025
75447d9
stub
joecummings Sep 5, 2025
2838937
Merge remote-tracking branch 'upstream/main' into working-updates
joecummings Sep 8, 2025
3120100
Queue while updating weights
joecummings Sep 8, 2025
8f4bda1
Cleanup
joecummings Sep 10, 2025
7825255
Make sd conversion happen on push
joecummings Sep 11, 2025
b511fe3
Sum over train_step valuemesh
joecummings Sep 11, 2025
9b46a77
Merge remote-tracking branch 'upstream/main' into working-updates
joecummings Sep 11, 2025
1a6d6df
Update config
joecummings Sep 11, 2025
e31f815
Loss updates
joecummings Sep 11, 2025
55c32be
Updated rewards (just played around a bit)
joecummings Sep 11, 2025
b74a47c
Update rewards
joecummings Sep 11, 2025
14d6354
Fix last math reward test
joecummings Sep 11, 2025
8fa4451
Async by 1
joecummings Sep 11, 2025
bdd03a8
Seg fault
joecummings Sep 12, 2025
7eedc91
Make torchstore actually work!
joecummings Sep 12, 2025
4044087
Last updates
joecummings Sep 12, 2025
187 changes: 91 additions & 96 deletions apps/grpo/main.py
@@ -7,16 +7,18 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable

import torch
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
@@ -26,12 +28,10 @@
from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from torch import nn
from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
@@ -50,25 +50,21 @@ def compute_logprobs(

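The body of compute_logprobs is collapsed in this view. As a rough sketch only, assuming the standard shift-by-one gather (not confirmed by this diff), a typical implementation using the torch and F imports above looks like:

# Assumed implementation sketch -- not taken from this PR.
# logits: [batch, prompt_len + response_len, vocab]; input_ids: [batch, response_len]
def compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
    context_len = logits.shape[1] - input_ids.shape[1]
    # Logits at position t predict token t + 1, so shift left by one.
    shifted = logits[:, context_len - 1 : -1, :] / temperature
    logprobs = F.log_softmax(shifted, dim=-1)
    # Pick out the log-probability of each generated response token.
    return torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
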
class SimpleGRPOLoss(nn.Module):
"""Simplified GRPO Loss for simplified single step updates
Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
Inspired by the Hugging Face TRL implementation:
https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
"""

def __init__(self, epsilon=0.1, beta=0.1):
def __init__(self, beta: float = 0.1):
super().__init__()
self.epsilon = epsilon
self.beta = beta

def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
per_token_kl = (
torch.exp(ref_logprobs.detach() - logprobs)
- (ref_logprobs.detach() - logprobs)
- 1
)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl)
per_token_loss = -(per_token_policy_loss - self.beta * kl)
loss = (
(per_token_loss * padding_mask).sum(dim=1)
/ (padding_mask.sum(dim=1) + 1e-8)
((per_token_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss
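
In equation form, the forward pass above computes (with $\ell_{i,t}=\log\pi_\theta(o_{i,t})$, $\ell_{i,t}^{\mathrm{ref}}=\log\pi_{\mathrm{ref}}(o_{i,t})$, padding mask $m_{i,t}$, advantage $A_i$, and $\mathrm{sg}[\cdot]$ denoting detach):

$$\hat{k}_{i,t} = e^{\ell_{i,t}^{\mathrm{ref}} - \ell_{i,t}} - \bigl(\ell_{i,t}^{\mathrm{ref}} - \ell_{i,t}\bigr) - 1$$

$$\mathcal{L} = \frac{1}{B}\sum_{i=1}^{B} \frac{\sum_t m_{i,t}\,\bigl(\beta\,\hat{k}_{i,t} - e^{\ell_{i,t} - \mathrm{sg}[\ell_{i,t}]}\,A_i\bigr)}{\max\bigl(\sum_t m_{i,t},\,1\bigr)}$$

The first term is the k3 KL estimator toward the reference policy; the second is the GRPO policy surrogate, which equals $A_i$ in value but carries the policy gradient through $\ell_{i,t}$.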

@@ -82,14 +78,14 @@ class Episode:
pad_id: int
request_len: int
response_len: int
target: Optional[Any] = None
target: Any | None = None
# processed data
response: Optional[str] = None
request_tokens: Optional[list[int]] = None
response_tokens: Optional[list[int]] = None
ref_logprobs: Optional[torch.Tensor] = None
reward: Optional[float] = None
advantage: Optional[float] = None
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
ref_logprobs: torch.Tensor | None = None
reward: float | None = None
advantage: float | None = None

@property
def request_tensor(self):
@@ -126,7 +122,7 @@ def new_group(
target: Any = None,
):
episodes = []
for i in range(group_size):
for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
@@ -148,78 +144,75 @@ class Trainer(ForgeActor):
model_name: str
learning_rate: float = 1e-5
beta: float = 0.1
epsilon: float = 0.1
device: torch.device | None = None
state_dict_key: str = "model_state_dict"
dp_rank: int = 0 # TODO: support data parallelism, hard code it for now

@endpoint
def setup(self):
# Set device
async def setup(self):
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.train()

# Initialize optimizer
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=self.learning_rate
)
self.optimizer.zero_grad()

# Initialize loss
self.loss = SimpleGRPOLoss(self.epsilon, self.beta)
self.loss = SimpleGRPOLoss(self.beta)

self.logger.info(f"Model initialized on {self.device}")
self.logger.info(f"Trainer model initialized on {self.device}")

@endpoint
async def train_step(self, batch: list[Episode]):
batch = batch[self.dp_rank]
pad_id = batch[0].pad_id
async def train_step(self, batch: list[list[Episode]]):
microbatch = batch[self.dp_rank]
pad_id = microbatch[0].pad_id

# prepare batch
request = [e.request_tensor for e in batch]
request = [e.request_tensor for e in microbatch]
request = torch.stack(request).to(self.device) # [b x s]

response = [e.response_tensor for e in batch]
response = [e.response_tensor for e in microbatch]
response = torch.stack(response).to(self.device) # [b x s]

ref_logprobs = [e.ref_logprobs for e in batch]
ref_logprobs = [e.ref_logprobs for e in microbatch]
ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze() # [b x s]

advantages = [e.advantage for e in batch]
advantages = [e.advantage for e in microbatch]
advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1]
del batch

# compute policy logprobs
input_ids = torch.cat([request, response], dim=1)
mask = input_ids != pad_id
logits = self.model(input_ids=input_ids, attention_mask=mask).logits
logprobs = compute_logprobs(logits, response)
del logits

# compute loss
mask = response != pad_id
loss = self.loss(logprobs, ref_logprobs, advantages, mask)

self.optimizer.zero_grad()
loss.backward()

# # Gradient clipping (optional but recommended for stability)
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

return {"loss": loss.item()}
return loss.item()

@endpoint
async def push_weights(self):
pass
async def push_weights(self, version: int):
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
start_time = time.time()
await ts.put_state_dict(new_sd, key)
end_time = time.time()
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)


@dataclass
@@ -230,11 +223,11 @@ class RewardActor(ForgeActor):

@endpoint
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
total_reward = 0.0
total_rewards = 0.0
for reward_fn in self.reward_functions:
reward = reward_fn(prompt, response, target)
total_reward += reward
return total_reward
total_rewards += reward
return total_rewards / len(self.reward_functions)


class ComputeAdvantages(ForgeActor):
@@ -243,18 +236,11 @@ class ComputeAdvantages(ForgeActor):
@endpoint
async def compute(self, group: Group) -> list[float]:
# TODO: add batch processing
Contributor suggested change:
- # TODO: add batch processing
+ # TODO: issues/120 add batch processing

rewards = torch.Tensor([[e.reward for e in group.episodes]])
rewards = torch.tensor([[e.reward for e in group.episodes]])
mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)

# if std is nan, return 0s. Remove this before shipping
if std.isnan().any():
advantages = torch.zeros_like(rewards)
else:
advantages = (rewards - mean) / (std + 1e-4)

x = advantages.squeeze(0).tolist()
return x
advantages = (rewards - mean) / (std + 1e-4)
return advantages.squeeze(0).tolist()
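
In equation form, the group-relative advantage computed above is

$$A_i = \frac{r_i - \operatorname{mean}(r_{1:G})}{\operatorname{std}(r_{1:G}) + 10^{-4}}, \qquad i = 1,\dots,G,$$

where $G$ is the group size and $r_i$ the reward of episode $i$.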


class RefModel(ForgeActor):
@@ -297,16 +283,24 @@ class DatasetActor(ForgeActor):
revision: str = "main"
data_split: str = "train"
streaming: bool = True
model: str = "Qwen/Qwen3-1.7B-Base"
model: str = "Qwen/Qwen3-1.7B"

@endpoint
def setup(self):
self.tokenizer = get_tokenizer(self.model)
self._tokenizer = get_tokenizer(self.model)

def gsm8k_transform(sample):
system_prompt = """
Put all your scratchpad work between <think> and </think> tags.
Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
"""
request: str = sample["question"]
formatted_request = self.tokenizer.apply_chat_template(
[{"role": "user", "content": request}],
as_chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": request},
]
formatted_request = self._tokenizer.apply_chat_template(
as_chat,
tokenize=False,
add_generation_prompt=True,
)
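
The remainder of gsm8k_transform is collapsed in this view. Assuming it follows the usual GSM8K convention, where the gold answer trails a "####" marker in the answer field, a hypothetical target extraction would look like:

# Hypothetical helper, not taken from this PR: GSM8K answers end with
# "#### <final answer>", so the target is the text after that marker.
def extract_gsm8k_target(answer_text: str) -> str:
    return answer_text.split("####")[-1].strip()
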
@@ -330,7 +324,7 @@ async def sample(self) -> dict[str, str] | None:

@endpoint
async def pad_token(self):
return self.tokenizer.pad_token_id
return self._tokenizer.pad_token_id


async def main(cfg: DictConfig):
@@ -340,15 +334,14 @@ async def main(cfg: DictConfig):
model = cfg.model
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
mlogger = get_metric_logger(
"wandb",
freq=1,
project="grpo-training",
)

# ---- Setup services ---- #
await ts.initialize()
(
dataloader,
policy,
@@ -371,7 +364,6 @@
spawn_service(
ServiceConfig(**cfg.trainer.service),
Trainer,
model_name=model,
**exclude_service(cfg.trainer),
),
spawn_service(
@@ -407,7 +399,8 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
version = 0 # await policy.get_current_version.choose()
responses = await policy.generate.choose(prompt)
Contributor: We'll throw away a lot of data this way for fully on policy.

Member Author: For short responses yeah definitely, if you look at the WandB logs (buffer_size/rollout), you can see that we build up a buffer of about 100 episodes and then evict the majority of them back and forth during weight updates. When we start allowing much longer generations and our models are much bigger, this won't be as big of an issue.

version = await policy.get_version.choose()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
@@ -419,12 +412,11 @@ async def continuous_rollouts():
target=target,
)

responses = await policy.generate.choose(prompt)

# TODO: Parallelize the following calculation
for episode, response in zip(group.episodes, responses.outputs):
episode.request_tokens = responses.prompt_token_ids
episode.response_tokens = response.token_ids
assert len(response.token_ids) <= max_res_tokens
episode.response = response.text
episode.ref_logprobs = await ref_model.forward.choose(episode)
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
@@ -434,30 +426,33 @@
episode.advantage = advantage
await replay_buffer.add.choose(episode)

avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
buffer_size = await replay_buffer._numel.choose()
mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
avg_reward = sum(e.reward for e in group.episodes) / group_size
mlogger.log("avg_reward/rollout", avg_reward, rollout_count)

rollout_count += 1
if rollout_count % 10 == 0:
avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes)
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
)
logger.log("reward_per_rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
policy_version = 0
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=0)
batch = await replay_buffer.sample.choose(
curr_policy_version=policy_version
)
if batch is None:
await asyncio.sleep(0.1)
else:
training_result = await trainer.train_step.choose(batch)
loss = await trainer.train_step.choose(batch)
Contributor: This should also be a call.

training_step += 1
if training_step % 10 == 0:
print(f"Completed {training_step} training steps")
if training_result:
loss_value = training_result.get("loss", 0.0)
print(f"Latest loss: {loss_value}")
logger.log("loss/training_step", loss_value, training_step)
# await trainer.update_weights(policy)
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.call(policy_version)
policy_version += 1
await policy.update_weights.call()

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all services support it
@@ -483,10 +478,10 @@ async def continuous_training():
)


@parse
def recipe_main(cfg: DictConfig) -> None:
asyncio.run(main(cfg))
if __name__ == "__main__":

@parse
Contributor: Can you explain this change?

def _main(cfg):
asyncio.run(main(cfg))

if __name__ == "__main__":
recipe_main()
_main() # @parse grabs the cfg from CLI