1 change: 1 addition & 0 deletions .gitignore
@@ -41,6 +41,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
.rsyncignore

# Django stuff
*.log
389 changes: 389 additions & 0 deletions apps/grpo/main_no_reward.py
@@ -0,0 +1,389 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage: python -m apps.grpo.main_no_reward --config apps/grpo/qwen3_1_7b.yaml

import asyncio
import uuid
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer


@dataclass
class Episode:
# TODO: add additional layer for multi-turn
episode_id: str
request: str
policy_version: int
pad_id: int
request_len: int
response_len: int
target: Any | None = None
# processed data
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
ref_logprobs: torch.Tensor | None = None
reward: float | None = None
advantage: float | None = None

@property
def request_tensor(self):
tensor = torch.tensor(self.request_tokens, dtype=torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
return tensor

@property
def response_tensor(self):
tensor = torch.tensor(self.response_tokens, dtype=torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor
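
# Padding example (illustrative): with request_len=5, pad_id=0 and
# request_tokens=[7, 8], request_tensor is [0, 0, 0, 7, 8]. Requests are
# left-padded and responses right-padded, so the two always meet at a fixed
# boundary when concatenated in collate().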


@dataclass
class Group:
group_id: str
episodes: list[Episode]

@classmethod
def new_group(
cls,
group_id: int,
group_size: int,
request: str,
policy_version: int,
pad_id: int,
request_len: int,
response_len: int,
target: Any = None,
):
episodes = []
for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
request=request,
policy_version=policy_version,
pad_id=pad_id,
request_len=request_len,
response_len=response_len,
target=target,
)
)
return cls(str(group_id), episodes)
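
# Every episode in a group shares the same request and target; the episodes
# differ only in the responses sampled later, which is exactly the grouping
# that GRPO's relative advantage is computed over.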


def collate(batches: list[list[Episode]]):
inputs = []
targets = []
for batch in batches:
request = [e.request_tensor for e in batch]
request = torch.stack(request) # [b x s]

response = [e.response_tensor for e in batch]
response = torch.stack(response) # [b x s]

# mock out the ref logprobs for now
ref_logprobs = torch.zeros(len(batch), batch[0].response_len) # [b x s]

advantages = [e.advantage for e in batch]
advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]

pad_id = batch[0].pad_id
mask = response != pad_id

input = {"tokens": torch.cat([request, response], dim=1)}
target = {
"response": response,
"ref_logprobs": ref_logprobs,
"advantages": advantages,
"padding_mask": mask,
}
inputs.append(input)
targets.append(target)
return inputs, targets
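
# Example (illustrative): one batch of 2 episodes with request_len=4 and
# response_len=3 collates to a [2 x 7] "tokens" tensor and an "advantages"
# entry of shape [2 x 1] -- one scalar per episode, broadcast over that
# episode's response tokens inside the loss.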


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]
logits = logits[:, context_length - 1 : -1]
logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
return logprobs
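
# Shape walk-through (illustrative): with logits of shape [b, 12, vocab] and
# 4 response tokens in input_ids, context_length = 8 and the slice keeps
# logits[:, 7:11] -- the positions whose logits predict each response token,
# since the logit at position t scores the token at position t + 1.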


def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
ref_logprobs: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
logprobs = compute_logprobs(logits, response)
# kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
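# NOTE: with the KL term above disabled, beta and ref_logprobs are accepted
# but unused. exp(logprobs - logprobs.detach()) is 1.0 in the forward pass;
# it exists to route gradients through logprobs (the usual GRPO surrogate).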
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -per_token_policy_loss
loss = (
((per_token_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
).mean()
return loss


@dataclass
class RewardActor(ForgeActor):
"""Reward actor that uses a list of scoring functions."""

reward_functions: list[Callable]

@endpoint
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
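# Unweighted mean across the configured reward functions; main() below wires
# in MathReward and ThinkingReward.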
total_rewards = 0.0
for reward_fn in self.reward_functions:
reward = reward_fn(prompt, response, target)
total_rewards += reward
return total_rewards / len(self.reward_functions)


class ComputeAdvantages(ForgeActor):
"""Compute advantages for GRPO using reward signals."""

@endpoint
async def compute(self, group: Group) -> list[float]:
# TODO: add batch processing
rewards = torch.tensor([[e.reward for e in group.episodes]])
mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)
advantages = (rewards - mean) / (std + 1e-4)
return advantages.squeeze(0).tolist()
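
# Example (illustrative): group rewards [0.0, 0.5, 1.0] normalize to roughly
# [-1.0, 0.0, 1.0]; the 1e-4 epsilon keeps the division finite when every
# episode in the group earns the same reward.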


@dataclass
class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

path: str = "openai/gsm8k"
revision: str = "main"
data_split: str = "train"
streaming: bool = True
model: str = "Qwen/Qwen3-1.7B"

@endpoint
def setup(self):
self._tokenizer = get_tokenizer(self.model)

def gsm8k_transform(sample):
system_prompt = """
Put all your scratchpad work between <think> and </think> tags.
Your final answer should be between <answer> and </answer> tags, otherwise it will not be scored.
"""
request: str = sample["question"]
as_chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": request},
]
formatted_request = self._tokenizer.apply_chat_template(
as_chat,
tokenize=False,
add_generation_prompt=True,
)
target: str = sample["answer"]
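# GSM8K answers end with "#### <final answer>"; keep only that final part.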
formatted_target = target.split("#### ")[1]
return {"request": formatted_request, "target": formatted_target}

ds = load_dataset(
self.path, self.revision, split=self.data_split, streaming=self.streaming
)
ds = ds.map(gsm8k_transform)
ds = ds.shuffle()
self._iterator = iter(ds)

@endpoint
async def sample(self) -> dict[str, str] | None:
try:
return next(self._iterator)
except StopIteration:
return None

@endpoint
async def pad_token(self):
return self._tokenizer.pad_token_id


async def main(cfg: DictConfig):
"""Main GRPO training loop with rollout and training processes."""
group_size = cfg.group_size
max_req_tokens = cfg.max_req_tokens
max_res_tokens = cfg.max_res_tokens
mlogger = get_metric_logger(
"wandb",
freq=1,
project="grpo-training",
)

# ---- Setup services ---- #
await ts.initialize(strategy=ts.ControllerStorageVolumes())
(
dataloader,
policy,
trainer,
replay_buffer,
compute_advantages,
# ref_model,
reward_actor,
) = await asyncio.gather(
DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
RLTrainer.options(**cfg.services.trainer).as_service(
**cfg.trainer, loss=simple_grpo_loss
),
ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
**cfg.replay_buffer, collate=collate
),
ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
# ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
reward_functions=[MathReward(), ThinkingReward()]
),
)
print("All services initialized successfully!")

# ---- Core RL loops ---- #
async def continuous_rollouts():
rollout_count = 0
pad_id = await dataloader.pad_token.choose()
while True:
sample = await dataloader.sample.choose()
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
responses = await policy.generate.choose(prompt)
# TODO: this should be part of the response metadata instead of a separate call
version = await policy.get_version.choose()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
request=prompt,
policy_version=version,
pad_id=pad_id,
request_len=max_req_tokens,
response_len=max_res_tokens,
target=target,
)

input_ids = torch.ones(
(group_size, max_req_tokens + max_res_tokens),
dtype=torch.long,
device="cuda",
)
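# input_ids feeds the reference-model pass, currently commented out below;
# it is populated anyway so that block can be re-enabled without changes.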
# Populate episode info and calculate rewards
for i, (episode, response) in enumerate(zip(group.episodes, responses)):
episode.request_tokens = response.prompt_ids
episode.response_tokens = response.token_ids
episode.response = response.text
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
)

# Calculate reference logprobs
# ref_logits = await ref_model.forward.choose(input_ids)
# ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
# for i, episode in enumerate(group.episodes):
# episode.ref_logprobs = ref_logprobs[i]
# del ref_logits, ref_logprobs, input_ids

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.choose(group)
for episode, advantage in zip(group.episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.choose(episode)

# Log metrics
avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
buffer_size = await replay_buffer._numel.choose()
mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
avg_reward = sum(e.reward for e in group.episodes) / group_size
mlogger.log("avg_reward/rollout", avg_reward, rollout_count)

rollout_count += 1

async def continuous_training():
training_step = 0
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
if batch is None:
await asyncio.sleep(0.1)
else:
inputs, targets = batch
loss = await trainer.train_step.choose(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
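# training_step doubles as the policy version: the trainer publishes weights
# tagged with it, the policy loads that same tag, and replay_buffer.sample()
# above filters on it via curr_policy_version.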
await trainer.push_weights.call(training_step)
await policy.update_weights.call(training_step)

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all services support it
rollout_task = asyncio.create_task(continuous_rollouts())
training_task = asyncio.create_task(continuous_training())

try:
await asyncio.gather(rollout_task, training_task)
except KeyboardInterrupt:
print("Training interrupted by user")
rollout_task.cancel()
training_task.cancel()
finally:
print("Shutting down...")
await asyncio.gather(
dataloader.shutdown(),
policy.shutdown(),
trainer.shutdown(),
replay_buffer.shutdown(),
compute_advantages.shutdown(),
# ref_model.shutdown(),
reward_actor.shutdown(),
)
# TODO - add a global shutdown that implicitly shuts down all services
# and remote allocations
await shutdown()


if __name__ == "__main__":

@parse
def _main(cfg):
asyncio.run(main(cfg))

_main() # @parse grabs the cfg from CLI