2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -23,10 +23,10 @@
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.data.utils import exclude_service
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from torch import nn
from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
78 changes: 4 additions & 74 deletions apps/rl/llama3_8b.yaml
@@ -13,8 +13,7 @@ trainer:
model:
name: llama3
flavor: 8B
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct

hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct

optimizer:
name: AdamW
@@ -36,7 +35,7 @@

parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
data_parallel_shard_degree: 4
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
@@ -57,76 +56,7 @@ trainer:
selective_ac_option: op

replay_buffer:
batch_size: 2
batch_size: 4
max_policy_age: 2
seed: None

# policy:
# scheduler:
# scheduler: local # local | mast (not supported yet)
# num_hosts: 1
# num_gpus: 1
# oncall: torchtune
# identity: pytorch_distributed
# image: forge_workspace:latest
#
# model: "meta-llama/Llama-3.1-8B-Instruct"
# tensor_parallel_size: 2
# pipeline_parallel_size: 1
# enforce_eager: false

# postprocessor:
# scheduler:
# scheduler: local # local | mast (not supported yet)
# num_hosts: 1
# num_gpus: 1
# oncall: torchtune
# identity: pytorch_distributed
# image: forge_workspace:latest
#
# comm:
# trace_buf_size: 0
#
# optimizer:
# name: AdamW
# lr: 1e-5
# eps: 1e-8
#
# lr_scheduler:
# warmup_steps: 1
#
# training:
# local_batch_size: 1
# seq_len: 2048
# max_norm: 1.0
# steps: 5
# compile: false
# dataset: "c4"
#
# parallelism:
# data_parallel_replicate_degree: 1
# data_parallel_shard_degree: -1
# tensor_parallel_degree: 1
# pipeline_parallel_degree: 1
# context_parallel_degree: 1
# expert_parallel_degree: 1
# disable_loss_parallel: false
#
# checkpoint:
# enable: true
# folder: /tmp/Meta-Llama-3.1-8B-Instruct/
# interval: 500
# async_mode: "disabled"
#
# activation_checkpoint:
# mode: selective
# selective_ac_option: op
#

# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
dp_size: 4
143 changes: 140 additions & 3 deletions apps/rl/main.py
@@ -13,32 +13,169 @@
import asyncio
import logging
import sys
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
from forge.actors import ReplayBuffer, RLTrainer
from forge.cli.config import parse
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from omegaconf import DictConfig
from torch import Tensor

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class Episode:
# TODO: add additional layer for multi-turn
episode_id: str
request: str
policy_version: int
pad_id: int
request_len: int
response_len: int
target: Any | None = None
# processed data
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
ref_logprobs: Tensor | None = None
reward: float | None = None
advantage: float | None = None

@property
def request_tensor(self):
tensor = torch.tensor(self.request_tokens, dtype=torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
return tensor

@property
def response_tensor(self):
tensor = torch.tensor(self.response_tokens, dtype=torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor


def collate(batches: list[list[Episode]]):
Contributor: return type?
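A possible annotation, inferred from the inputs/targets built below (a sketch, not part of the PR):

def collate(
    batches: list[list[Episode]],
) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]:
    ...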

inputs = []
targets = []
for batch in batches:
request = [e.request_tensor for e in batch]
request = torch.stack(request) # [b x s]

response = [e.response_tensor for e in batch]
response = torch.stack(response) # [b x s]

ref_logprobs = [e.ref_logprobs for e in batch]
ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]

advantages = [e.advantage for e in batch]
advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]

pad_id = batch[0].pad_id
mask = response != pad_id

input = {"tokens": torch.cat([request, response], dim=1)}
target = {
"response": response,
"ref_logprobs": ref_logprobs,
"advantages": advantages,
"padding_mask": mask,
}
inputs.append(input)
targets.append(target)
return inputs, targets


def compute_logprobs(
logits: Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> Tensor:
context_length = logits.shape[1] - input_ids.shape[1]

# Truncate request logits and drop last
logits = logits[:, context_length - 1 : -1]

# Compute logprobs
logprobs = torch.log_softmax(logits / temperature, dim=-1)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)

return logprobs


def simple_grpo_loss(
logits: Tensor,
response: Tensor,
ref_logprobs: Tensor,
advantages: Tensor,
padding_mask: Tensor,
beta: float = 0.1,
):
"""Simplified GRPO Loss for simplified single step updates
Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
"""
logprobs = compute_logprobs(logits, response)
per_token_kl = (
torch.exp(ref_logprobs.detach() - logprobs)
- (ref_logprobs.detach() - logprobs)
- 1
)
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * per_token_kl)
loss = (
(per_token_loss * padding_mask).sum(dim=1) / (padding_mask.sum(dim=1) + 1e-8)
).mean()
return loss


async def run(cfg: DictConfig):
trainer, replay_buffer = await asyncio.gather(
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4),
ServiceConfig(procs_per_replica=4, with_gpus=True, num_replicas=1),
RLTrainer,
loss=simple_grpo_loss,
**cfg.trainer,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
collate=collate,
**cfg.replay_buffer,
),
)
print("Services initialized....")
print("Services initialized...")

print("Collecting Data...")
g = torch.manual_seed(0)
global_batch_size = cfg.replay_buffer.batch_size * cfg.replay_buffer.dp_size
Contributor: No need to change right now, but wouldn't the global batch size technically be a trainer config and not a replay buffer config? Is it possible to tell omegaconf to "make this value the same as the one defined elsewhere"?

Contributor Author (@pbontrager, Sep 15, 2025): Not a trainer config anymore, since replay_buffer is its own service and the trainer doesn't handle data or loading anymore.
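On the second question: OmegaConf supports ${...} value interpolation, so one key can mirror a value defined elsewhere in the config. A minimal sketch (key names are illustrative, not the repo's actual config):

from omegaconf import OmegaConf

# ${trainer.dp_size} is resolved on access, so replay_buffer.dp_size
# always tracks whatever is set under trainer.
cfg = OmegaConf.create(
    """
    trainer:
      dp_size: 4
    replay_buffer:
      batch_size: 4
      dp_size: ${trainer.dp_size}
    """
)
assert cfg.replay_buffer.dp_size == 4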

for i in range(global_batch_size):
req_len, res_len = torch.randint(64, 256, (2,), generator=g).tolist()
e = Episode(
episode_id=i,
request="",
policy_version=0,
pad_id=0,
request_len=256,
response_len=256,
request_tokens=torch.randint(64_000, (req_len,), generator=g).tolist(),
response_tokens=torch.randint(64_000, (res_len,), generator=g).tolist(),
ref_logprobs=torch.randn((256,), generator=g),
advantage=torch.randn((1,), generator=g),
)
await replay_buffer.add.choose(e)

print("Train step...")
inputs, targets = await replay_buffer.sample.choose(curr_policy_version=0)
outputs = await trainer.train_step.choose(inputs, targets)
print("Loss: ", outputs["loss"])

print("shutting down...")
print("Shutting down...")
await shutdown_service(trainer)
await shutdown_service(replay_buffer)

9 changes: 5 additions & 4 deletions src/forge/actors/replay_buffer.py
@@ -6,7 +6,7 @@

import random
from dataclasses import dataclass
from typing import Any
from typing import Any, Callable

from monarch.actor import endpoint

@@ -21,6 +21,7 @@ class ReplayBuffer(ForgeActor):
max_policy_age: int
dp_size: int = 1
seed: int | None = None
collate: Callable = lambda batch: batch

@endpoint
async def setup(self) -> None:
@@ -31,7 +32,7 @@ async def setup(self) -> None:
self.sampler = random.sample

@endpoint
async def add(self, episode) -> None:
async def add(self, episode: "Episode") -> None:
self.buffer.append(episode)

@endpoint
@@ -55,7 +56,7 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
if total_samples > len(self.buffer):
return None

# TODO: Make this more efficient
# TODO: prefetch samples in advance
idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
# Pop episodes in descending order to avoid shifting issues
popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
@@ -71,7 +72,7 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
for dp_idx in range(self.dp_size)
]

return reshaped_episodes
return self.collate(reshaped_episodes)

@endpoint
async def evict(self, curr_policy_version: int) -> None: