Off-by-1 GRPO #140
@@ -7,16 +7,18 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable

import torch
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
@@ -26,12 +28,10 @@
from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from torch import nn
from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
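Note: the diff only shows the signature of compute_logprobs. For orientation, here is a minimal sketch of the causal-LM alignment such a function has to get right (the logit at position i predicts token i + 1, so the logits for the response tokens are the last response_len positions, shifted back by one). This is a generic illustration with assumed shapes, not the repository's exact implementation.

import torch
import torch.nn.functional as F

def compute_logprobs_sketch(
    logits: torch.Tensor,        # [b, s_req + s_res, vocab] from the full request+response pass
    response_ids: torch.Tensor,  # [b, s_res] sampled response tokens
    temperature: float = 1.0,
) -> torch.Tensor:
    res_len = response_ids.size(1)
    # The logit at position i predicts token i + 1, so drop the final position and
    # keep the last `res_len` of what remains to line up with the response tokens.
    shifted = logits[:, -res_len - 1 : -1, :] / temperature
    logprobs = F.log_softmax(shifted, dim=-1)
    # Gather the log-prob assigned to each sampled response token: [b, s_res]
    return torch.gather(logprobs, 2, response_ids.unsqueeze(-1)).squeeze(-1)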
@@ -50,25 +50,21 @@ def compute_logprobs(

class SimpleGRPOLoss(nn.Module):
    """Simplified GRPO Loss for simplified single step updates
    Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
    Inspired by the Hugging Face TRL implementation:
    https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
    """

    def __init__(self, epsilon=0.1, beta=0.1):
    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.epsilon = epsilon
        self.beta = beta

    def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
        per_token_kl = (
            torch.exp(ref_logprobs.detach() - logprobs)
            - (ref_logprobs.detach() - logprobs)
            - 1
        )
        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
        per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl)
        per_token_loss = -(per_token_policy_loss - self.beta * kl)
        loss = (
            (per_token_loss * padding_mask).sum(dim=1)
            / (padding_mask.sum(dim=1) + 1e-8)
            ((per_token_loss * padding_mask).sum(dim=1))
            / (padding_mask.sum(dim=1).clamp(min=1.0))
        ).mean()
        return loss
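Note: for readers less familiar with this loss, here is a small self-contained walk-through of the math the new forward() implements, with made-up tensors; beta and the shapes are illustrative only.

import torch

# Toy shapes: batch of 2 sequences, 4 response tokens each.
logprobs = torch.randn(2, 4)        # current policy log-probs of the sampled tokens
ref_logprobs = torch.randn(2, 4)    # reference model log-probs of the same tokens
advantages = torch.tensor([[0.5], [-0.5]])            # one group-normalized advantage per sequence
padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]]).float()

beta = 0.1
# k3-style estimator of KL(pi || ref): exp(ref - pi) - (ref - pi) - 1, always >= 0.
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
# exp(logprobs - logprobs.detach()) equals 1 in the forward pass; its gradient gives the
# advantage-weighted policy-gradient term.
policy_term = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(policy_term - beta * kl)
loss = ((per_token_loss * padding_mask).sum(dim=1)
        / padding_mask.sum(dim=1).clamp(min=1.0)).mean()
print(loss)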
@@ -82,14 +78,14 @@ class Episode:
    pad_id: int
    request_len: int
    response_len: int
    target: Optional[Any] = None
    target: Any | None = None
    # processed data
    response: Optional[str] = None
    request_tokens: Optional[list[int]] = None
    response_tokens: Optional[list[int]] = None
    ref_logprobs: Optional[torch.Tensor] = None
    reward: Optional[float] = None
    advantage: Optional[float] = None
    response: str | None = None
    request_tokens: list[int] | None = None
    response_tokens: list[int] | None = None
    ref_logprobs: torch.Tensor | None = None
    reward: float | None = None
    advantage: float | None = None

    @property
    def request_tensor(self):
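Note: the body of request_tensor is cut off in this hunk. One plausible implementation pads the tokenized request out to request_len with pad_id; whether the real property left- or right-pads (and how it truncates) is not visible in this diff, so treat this as a hypothetical sketch.

import torch
import torch.nn.functional as F

def request_tensor(request_tokens: list[int], request_len: int, pad_id: int) -> torch.Tensor:
    # Hypothetical sketch of Episode.request_tensor; the real property may differ.
    t = torch.tensor(request_tokens[:request_len], dtype=torch.long)
    if t.numel() < request_len:
        # Left-pad so the prompt stays adjacent to the response when the two are concatenated.
        t = F.pad(t, (request_len - t.numel(), 0), value=pad_id)
    return t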
@@ -126,7 +122,7 @@ def new_group(
    target: Any = None,
):
    episodes = []
    for i in range(group_size):
    for _ in range(group_size):
        episodes.append(
            Episode(
                episode_id=str(uuid.uuid4()),
@@ -148,78 +144,75 @@ class Trainer(ForgeActor):
    model_name: str
    learning_rate: float = 1e-5
    beta: float = 0.1
    epsilon: float = 0.1
    device: torch.device | None = None
    state_dict_key: str = "model_state_dict"
    dp_rank: int = 0  # TODO: support data parallelism, hard code it for now

    @endpoint
    def setup(self):
        # Set device
    async def setup(self):
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            dtype=torch.bfloat16,
            trust_remote_code=True,
        ).to(self.device)
        self.model.train()

        # Initialize optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.learning_rate
        )
        self.optimizer.zero_grad()

        # Initialize loss
        self.loss = SimpleGRPOLoss(self.epsilon, self.beta)
        self.loss = SimpleGRPOLoss(self.beta)

        self.logger.info(f"Model initialized on {self.device}")
        self.logger.info(f"Trainer model initialized on {self.device}")

    @endpoint
    async def train_step(self, batch: list[Episode]):
        batch = batch[self.dp_rank]
        pad_id = batch[0].pad_id
    async def train_step(self, batch: list[list[Episode]]):
        microbatch = batch[self.dp_rank]
        pad_id = microbatch[0].pad_id

        # prepare batch
        request = [e.request_tensor for e in batch]
        request = [e.request_tensor for e in microbatch]
        request = torch.stack(request).to(self.device)  # [b x s]

        response = [e.response_tensor for e in batch]
        response = [e.response_tensor for e in microbatch]
        response = torch.stack(response).to(self.device)  # [b x s]

        ref_logprobs = [e.ref_logprobs for e in batch]
        ref_logprobs = [e.ref_logprobs for e in microbatch]
        ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze()  # [b x s]

        advantages = [e.advantage for e in batch]
        advantages = [e.advantage for e in microbatch]
        advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1)  # [b x 1]
        del batch

        # compute policy logprobs
        input_ids = torch.cat([request, response], dim=1)
        mask = input_ids != pad_id
        logits = self.model(input_ids=input_ids, attention_mask=mask).logits
        logprobs = compute_logprobs(logits, response)
        del logits

        # compute loss
        mask = response != pad_id
        loss = self.loss(logprobs, ref_logprobs, advantages, mask)

        self.optimizer.zero_grad()
        loss.backward()

        # # Gradient clipping (optional but recommended for stability)
        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

        self.optimizer.step()

        return {"loss": loss.item()}
        return loss.item()

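Note on the new train_step signature: the replay buffer now hands the trainer a list of per-data-parallel-rank microbatches, and each rank indexes its own slice with dp_rank. A minimal, hypothetical illustration of that shape (Episode is reduced to the two fields the loop touches here):

from dataclasses import dataclass

@dataclass
class MiniEpisode:          # stand-in for Episode, illustration only
    pad_id: int
    advantage: float

# batch: list[list[Episode]] -- the outer index is the data-parallel rank
batch = [
    [MiniEpisode(pad_id=0, advantage=0.7), MiniEpisode(pad_id=0, advantage=-0.7)],  # dp rank 0
    [MiniEpisode(pad_id=0, advantage=0.2), MiniEpisode(pad_id=0, advantage=-0.2)],  # dp rank 1
]

dp_rank = 0
microbatch = batch[dp_rank]                      # this rank's episodes
pad_id = microbatch[0].pad_id
advantages = [e.advantage for e in microbatch]   # -> [0.7, -0.7]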
    @endpoint
    async def push_weights(self):
        pass
    async def push_weights(self, version: int):
        """Update policy model weights with trainer's current weights."""
        key = f"{self.state_dict_key}{DELIM}{version}"  # Use version as unique id
        new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
        start_time = time.time()
        await ts.put_state_dict(new_sd, key)
        end_time = time.time()
        self.logger.debug(
            f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
        )


@dataclass
@@ -230,11 +223,11 @@ class RewardActor(ForgeActor):

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
        total_reward = 0.0
        total_rewards = 0.0
        for reward_fn in self.reward_functions:
            reward = reward_fn(prompt, response, target)
            total_reward += reward
        return total_reward
            total_rewards += reward
        return total_rewards / len(self.reward_functions)
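Note: this change switches the episode reward from the sum of all reward functions to their mean, so adding another reward function no longer inflates the reward scale. A tiny illustration with two made-up reward functions (the real ones are defined elsewhere in the repo):

# Hypothetical reward functions, for illustration only.
def exactness_reward(prompt: str, response: str, target: str) -> float:
    return 1.0 if target in response else 0.0

def brevity_reward(prompt: str, response: str, target: str) -> float:
    return 1.0 if len(response) < 200 else 0.0

reward_functions = [exactness_reward, brevity_reward]
prompt, response, target = "2+2?", "<answer>4</answer>", "4"

total = sum(fn(prompt, response, target) for fn in reward_functions)  # old behavior: 2.0
mean = total / len(reward_functions)                                  # new behavior: 1.0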


class ComputeAdvantages(ForgeActor):
@@ -243,18 +236,11 @@ class ComputeAdvantages(ForgeActor):
    @endpoint
    async def compute(self, group: Group) -> list[float]:
        # TODO: add batch processing
        rewards = torch.Tensor([[e.reward for e in group.episodes]])
        rewards = torch.tensor([[e.reward for e in group.episodes]])
        mean = rewards.mean(1, keepdim=True)
        std = rewards.std(1, keepdim=True)

        # if std is nan, return 0s. Remove this before shipping
        if std.isnan().any():
            advantages = torch.zeros_like(rewards)
        else:
            advantages = (rewards - mean) / (std + 1e-4)

        x = advantages.squeeze(0).tolist()
        return x
        advantages = (rewards - mean) / (std + 1e-4)

        return advantages.squeeze(0).tolist()
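Note: for concreteness, the group-normalized advantage after this change on a made-up group of four rewards; the 1e-4 keeps the division finite when every reward in the group is identical.

import torch

rewards = torch.tensor([[0.0, 1.0, 1.0, 0.0]])        # one rollout group of 4 episodes
mean = rewards.mean(1, keepdim=True)                   # 0.5
std = rewards.std(1, keepdim=True)                     # ~0.577 (unbiased, n-1)
advantages = (rewards - mean) / (std + 1e-4)           # ~[-0.866, 0.866, 0.866, -0.866]
print(advantages.squeeze(0).tolist())

# Degenerate group: identical rewards give std == 0, so every advantage is exactly 0.
same = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
print((same - same.mean(1, keepdim=True)) / (same.std(1, keepdim=True) + 1e-4))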


class RefModel(ForgeActor):
@@ -297,16 +283,24 @@ class DatasetActor(ForgeActor):
    revision: str = "main"
    data_split: str = "train"
    streaming: bool = True
    model: str = "Qwen/Qwen3-1.7B-Base"
    model: str = "Qwen/Qwen3-1.7B"

    @endpoint
    def setup(self):
        self.tokenizer = get_tokenizer(self.model)
        self._tokenizer = get_tokenizer(self.model)

        def gsm8k_transform(sample):
            system_prompt = """
            Put all your scratchpad work between <think> and </think> tags.
            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
            """
            request: str = sample["question"]
            formatted_request = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": request}],
            as_chat = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": request},
            ]
            formatted_request = self._tokenizer.apply_chat_template(
                as_chat,
                tokenize=False,
                add_generation_prompt=True,
            )
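Note: the new system prompt asks the model to wrap its final answer in <answer> tags, which only pays off if the reward side extracts that span. The repository's reward functions are not part of this diff; a hypothetical extraction helper of the kind they would need might look like this:

import re

def extract_answer(response: str) -> str | None:
    """Hypothetical helper: pull the text between <answer> and </answer> tags."""
    match = re.search(r"<answer>(.*?)</answer>", response, flags=re.DOTALL)
    return match.group(1).strip() if match else None

assert extract_answer("<think>6*12=72</think><answer>72</answer>") == "72"
assert extract_answer("no tags here") is None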
@@ -330,7 +324,7 @@ async def sample(self) -> dict[str, str] | None:

    @endpoint
    async def pad_token(self):
        return self.tokenizer.pad_token_id
        return self._tokenizer.pad_token_id


async def main(cfg: DictConfig):
@@ -340,15 +334,14 @@ async def main(cfg: DictConfig):
    model = cfg.model
    max_req_tokens = cfg.max_req_tokens
    max_res_tokens = cfg.max_res_tokens

    # ---- Setup WandB Logger ---- #
    logger = get_metric_logger(
    mlogger = get_metric_logger(
        "wandb",
        freq=1,
        project="grpo-training",
    )

    # ---- Setup services ---- #
    await ts.initialize()
    (
        dataloader,
        policy,
@@ -371,7 +364,6 @@
        spawn_service(
            ServiceConfig(**cfg.trainer.service),
            Trainer,
            model_name=model,
            **exclude_service(cfg.trainer),
        ),
        spawn_service(
@@ -407,7 +399,8 @@ async def continuous_rollouts():
                print("Dataloader is empty, exiting continuous rollout")
                return
            prompt, target = sample["request"], sample["target"]
            version = 0  # await policy.get_current_version.choose()
            responses = await policy.generate.choose(prompt)
            version = await policy.get_version.choose()
            group = Group.new_group(
                group_id=rollout_count,
                group_size=group_size,

Review comment: We'll throw away a lot of data this way for fully on policy

Reply: For short responses yeah definitely, if you look at the WandB logs (buffer_size/rollout), you can see that we build up a buffer of about 100 episodes and then evict the majority of them back and forth during weight updates. When we start allowing much longer generations and our models are much bigger, this won't be as big of an issue.
@@ -419,12 +412,11 @@ async def continuous_rollouts():
                target=target,
            )

            responses = await policy.generate.choose(prompt)

            # TODO: Parallelize the following calculation
            for episode, response in zip(group.episodes, responses.outputs):
                episode.request_tokens = responses.prompt_token_ids
                episode.response_tokens = response.token_ids
                assert len(response.token_ids) <= max_res_tokens
                episode.response = response.text
                episode.ref_logprobs = await ref_model.forward.choose(episode)
                episode.reward = await reward_actor.evaluate_response.choose(
                    prompt=prompt, response=response.text, target=target
@@ -434,30 +426,33 @@ async def continuous_rollouts():
                episode.advantage = advantage
                await replay_buffer.add.choose(episode)

            avg_response_len = (
                sum(len(e.response_tokens) for e in group.episodes) / group_size
            )
            mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
            buffer_size = await replay_buffer._numel.choose()
            mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
            avg_reward = sum(e.reward for e in group.episodes) / group_size
            mlogger.log("avg_reward/rollout", avg_reward, rollout_count)

            rollout_count += 1
            if rollout_count % 10 == 0:
                avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes)
                print(
                    f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
                )
                logger.log("reward_per_rollout", avg_reward, rollout_count)

    async def continuous_training():
        training_step = 0
        policy_version = 0
        while True:
            batch = await replay_buffer.sample.choose(curr_policy_version=0)
            batch = await replay_buffer.sample.choose(
                curr_policy_version=policy_version
            )
            if batch is None:
                await asyncio.sleep(0.1)
            else:
                training_result = await trainer.train_step.choose(batch)
                loss = sum(await trainer.train_step.call(batch))

                training_step += 1
                if training_step % 10 == 0:
                    print(f"Completed {training_step} training steps")
                    if training_result:
                        loss_value = training_result.get("loss", 0.0)
                        print(f"Latest loss: {loss_value}")
                        logger.log("loss/training_step", loss_value, training_step)
                # await trainer.update_weights(policy)
                mlogger.log("loss/training_step", loss, training_step)
                await trainer.push_weights.call(policy_version)
                policy_version += 1
                await policy.update_weights.call()

    print("Starting GRPO training loops...")
    # TODO: Start multiple rollouts once all serivces support it
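Note on the version bookkeeping above: the training loop now samples from the replay buffer with its current policy_version, pushes new weights under that same version, bumps the counter, and then tells the policy to pull the fresh weights. Episodes generated under the previous version can therefore still be sampled one step later, which is presumably the "off-by-1" in the PR title. The ReplayBuffer.sample staleness logic is not shown in this diff; a hypothetical filter it might apply looks like:

# Hypothetical staleness filter, illustration only -- the real ReplayBuffer.sample
# implementation is not part of this diff.
def is_sampleable(episode_version: int, curr_policy_version: int, max_lag: int = 1) -> bool:
    # Allow episodes generated by the current policy or by the previous one ("off by 1").
    return curr_policy_version - episode_version <= max_lag

assert is_sampleable(episode_version=3, curr_policy_version=3)      # fully on-policy
assert is_sampleable(episode_version=2, curr_policy_version=3)      # one version stale: still usable
assert not is_sampleable(episode_version=1, curr_policy_version=3)  # too stale: evicted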
@@ -483,10 +478,10 @@ async def continuous_training():

@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(main(cfg))
if __name__ == "__main__":

    @parse
    def _main(cfg):
        asyncio.run(main(cfg))

if __name__ == "__main__":
    recipe_main()
    _main()  # @parse grabs the cfg from CLI

Review comment (on the @parse entrypoint change): Can you explain this change?