Merged
211 changes: 79 additions & 132 deletions apps/grpo/main.py
@@ -7,7 +7,6 @@
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
@@ -17,38 +16,19 @@
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel # noqa: F401
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
from torchstore.state_dict_utils import DELIM
from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]

# Truncate request logits and drop last
logits = logits[:, context_length - 1 : -1]

# Compute logprobs
logprobs = torch.log_softmax(logits / temperature, dim=-1)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)

return logprobs


@dataclass
class Episode:
# TODO: add additional layer for multi-turn
@@ -117,82 +97,64 @@ def new_group(
return cls(str(group_id), episodes)


@dataclass
class Trainer(ForgeActor):
"""GRPO Trainer implementation for policy optimization."""

model_name: str
learning_rate: float = 1e-5
beta: float = 0.1
device: torch.device | None = None
state_dict_key: str = "model_state_dict"
dp_rank: int = 0 # TODO: support data parallelism, hard code it for now

@endpoint
async def setup(self):
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.train()

self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=self.learning_rate
)
self.optimizer.zero_grad()
def collate(batches: list[list[Episode]]):
Contributor review comment:
Not for this PR, but I would prefer this conversion to be a classmethod on Episode (cc @Jack-Khuu). (A sketch of that suggested refactor follows the collate definition below.)

inputs = []
targets = []
for batch in batches:
request = [e.request_tensor for e in batch]
request = torch.stack(request) # [b x s]

self.loss = SimpleGRPOLoss(self.beta)
response = [e.response_tensor for e in batch]
response = torch.stack(response) # [b x s]

self.logger.info(f"Trainer model initialized on {self.device}")
ref_logprobs = [e.ref_logprobs for e in batch]
ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]

@endpoint
async def train_step(self, batch: list[list[Episode]]):
microbatch = batch[self.dp_rank]
pad_id = microbatch[0].pad_id
advantages = [e.advantage for e in batch]
advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]

# prepare batch
request = [e.request_tensor for e in microbatch]
request = torch.stack(request).to(self.device) # [b x s]

response = [e.response_tensor for e in microbatch]
response = torch.stack(response).to(self.device) # [b x s]

ref_logprobs = [e.ref_logprobs for e in microbatch]
ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze() # [b x s]
pad_id = batch[0].pad_id
mask = response != pad_id

advantages = [e.advantage for e in microbatch]
advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1) # [b x 1]
del batch
input = {"tokens": torch.cat([request, response], dim=1)}
target = {
"response": response,
"ref_logprobs": ref_logprobs,
"advantages": advantages,
"padding_mask": mask,
}
inputs.append(input)
targets.append(target)
return inputs, targets
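
As context for the reviewer's suggestion above, here is a minimal sketch of how this collation logic might live on Episode as a classmethod rather than a free function. It is illustrative only, is not part of this PR, and assumes the Episode fields already used by collate (request_tensor, response_tensor, ref_logprobs, advantage, pad_id):

```python
# Hypothetical refactor, not part of this diff: the same logic as a classmethod
# that would sit inside the Episode dataclass defined earlier in this file.
@classmethod
def collate(cls, batches: list[list["Episode"]]) -> tuple[list[dict], list[dict]]:
    inputs, targets = [], []
    for batch in batches:
        request = torch.stack([e.request_tensor for e in batch])      # [b x s]
        response = torch.stack([e.response_tensor for e in batch])    # [b x s]
        ref_logprobs = torch.stack([e.ref_logprobs for e in batch]).squeeze()  # [b x s]
        advantages = torch.tensor([e.advantage for e in batch]).unsqueeze(-1)  # [b x 1]
        padding_mask = response != batch[0].pad_id
        inputs.append({"tokens": torch.cat([request, response], dim=1)})
        targets.append(
            {
                "response": response,
                "ref_logprobs": ref_logprobs,
                "advantages": advantages,
                "padding_mask": padding_mask,
            }
        )
    return inputs, targets
```

The service wiring further down would then pass Episode.collate to the ReplayBuffer in place of the module-level function.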

input_ids = torch.cat([request, response], dim=1)
mask = input_ids != pad_id
logits = self.model(input_ids=input_ids, attention_mask=mask).logits
logprobs = compute_logprobs(logits, response)
del logits

mask = response != pad_id
loss = self.loss(logprobs, ref_logprobs, advantages, mask)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]
logits = logits[:, context_length - 1 : -1]
logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
return logprobs
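
A note on the indexing above (illustrative, not from the diff): the logits at sequence position i score token i + 1, so the slice logits[:, context_length - 1 : -1] keeps exactly the positions that predict each response token. For example, with a request of length 3 and a response of length 2, context_length is 3 and the slice keeps positions 2 and 3. A quick sanity check, assuming compute_logprobs as defined above is in scope:

```python
# Illustrative shape check for compute_logprobs (not part of this PR).
import torch

req_len, res_len, vocab = 3, 2, 7
logits = torch.randn(1, req_len + res_len, vocab)   # full-sequence logits
response = torch.randint(0, vocab, (1, res_len))    # response token ids only

out = compute_logprobs(logits, response)
assert out.shape == (1, res_len)                    # one logprob per response token
```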

return loss.item()

@endpoint
async def push_weights(self, version: int):
"""Update policy model weights with trainer's current weights."""
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
start_time = time.time()
await ts.put_state_dict(new_sd, key)
end_time = time.time()
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)
def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
ref_logprobs: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
logprobs = compute_logprobs(logits, response)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss
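
For orientation (not part of the diff): the kl term above is the common k3 KL estimator exp(r) - r - 1 with r = ref_logprobs - logprobs, and exp(logprobs - logprobs.detach()) evaluates to 1 in the forward pass while still routing the policy gradient through logprobs. Per-token losses are then masked by padding_mask and normalized by each sequence's unpadded length before averaging over the batch. A minimal shape check, assuming both functions above are in scope:

```python
# Illustrative use of simple_grpo_loss (not part of this PR).
import torch

b, req_len, res_len, vocab = 2, 4, 3, 11
logits = torch.randn(b, req_len + res_len, vocab, requires_grad=True)
response = torch.randint(1, vocab, (b, res_len))    # avoid 0 so nothing looks padded
ref_logprobs = torch.randn(b, res_len)
advantages = torch.randn(b, 1)                      # one advantage per episode
padding_mask = response != 0                        # pretend pad_id == 0

loss = simple_grpo_loss(logits, response, ref_logprobs, advantages, padding_mask)
loss.backward()                                     # scalar loss; gradients reach logits
```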


@dataclass
@@ -223,38 +185,6 @@ async def compute(self, group: Group) -> list[float]:
return advantages.squeeze(0).tolist()
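
The body of ComputeAdvantages.compute is collapsed in this view. For reference, GRPO typically uses a group-relative advantage: each episode's reward is standardized against the mean and standard deviation of the rewards in its rollout group. A sketch with made-up reward values:

```python
# Typical GRPO group-relative advantage (sketch with hypothetical rewards; the
# actual compute() body is collapsed in this diff view).
import torch

rewards = torch.tensor([[1.0, 0.0, 0.5, 1.0]])      # one row per group, one entry per episode
advantages = (rewards - rewards.mean(1, keepdim=True)) / (rewards.std(1, keepdim=True) + 1e-4)
print(advantages.squeeze(0).tolist())
```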


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

self.model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, episode: Episode) -> torch.Tensor:
req, res = episode.request_tensor, episode.response_tensor
input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0)
mask = input_ids != episode.pad_id

with torch.inference_mode():
logits = self.model(input_ids=input_ids, attention_mask=mask).logits

input_ids = input_ids[:, len(req) :]
return compute_logprobs(logits, input_ids)


@dataclass
class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""
@@ -309,10 +239,7 @@ async def pad_token(self):

async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
titan_model = TitanJobModelConfig(name="qwen3", flavor="1.7B")
# Get parameters from config with fallbacks
group_size = cfg.group_size
model = cfg.model
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens
mlogger = get_metric_logger(
@@ -334,17 +261,18 @@ async def main(cfg: DictConfig):
) = await asyncio.gather(
DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer),
RLTrainer.options(**cfg.services.trainer).as_service(
**cfg.trainer, loss=simple_grpo_loss
),
ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
**cfg.replay_buffer
**cfg.replay_buffer, collate=collate
),
ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
reward_functions=[MathReward(), ThinkingReward()]
),
)

print("All services initialized successfully!")

# ---- Core RL loops ---- #
@@ -370,20 +298,38 @@ async def continuous_rollouts():
target=target,
)

# TODO: Parallelize the following calculation
for episode, response in zip(group.episodes, responses.outputs):
input_ids = torch.ones(
(group_size, max_req_tokens + max_res_tokens),
dtype=torch.long,
device="cuda",
)
# Populate episode info and calculate rewards
for i, (episode, response) in enumerate(
zip(group.episodes, responses.outputs)
):
episode.request_tokens = responses.prompt_token_ids
episode.response_tokens = response.token_ids
episode.response = response.text
episode.ref_logprobs = await ref_model.forward.choose(episode)
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
)

# Calculate reference logprobs
ref_logits = await ref_model.forward.choose(input_ids)
ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
for i, episode in enumerate(group.episodes):
episode.ref_logprobs = ref_logprobs[i]
del ref_logits, ref_logprobs, input_ids

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.choose(group)
for episode, advantage in zip(group.episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.choose(episode)

# Log metrics
avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
@@ -405,7 +351,8 @@ async def continuous_training():
if batch is None:
await asyncio.sleep(0.1)
else:
loss = await trainer.train_step.choose(batch)
inputs, targets = batch
loss = await trainer.train_step.choose(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
await trainer.push_weights.call(policy_version)