Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/unit_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ jobs:
run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
- name: Install monarch
run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
- name: Install torchstore
run: |
eval "$(ssh-agent -s)"
ssh-add - <<< '${{ secrets.FORGE_GITHUB_CI_FOR_TORCHSTORE }}'
python -m pip install git+ssh://[email protected]/meta-pytorch/torchstore.git
- name: Install dependencies
run: python -m pip install --no-build-isolation -e ".[dev]"
- name: Run unit tests with coverage
Expand Down
98 changes: 57 additions & 41 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Optional
Expand All @@ -21,12 +21,11 @@
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torch import nn
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, push_state_dict
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
Expand Down Expand Up @@ -121,7 +120,7 @@ def new_group(
target: Any = None,
):
episodes = []
for i in range(group_size):
for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
Expand All @@ -145,6 +144,8 @@ class Trainer(ForgeActor):
beta: float = 0.1
epsilon: float = 0.1
device: torch.device | None = None
store: MultiProcessStore | None = None
state_dict_key: str = "model_state_dict"

@endpoint
def setup(self):
Expand Down Expand Up @@ -208,11 +209,19 @@ async def train_step(self, batch: list[Episode]):

self.optimizer.step()

return {"loss": loss.item()}
return loss.item()

@endpoint
async def push_weights(self, version: int):
    """Update policy model weights with trainer's current weights.

    Serializes the trainer model's full state dict into the shared
    torchstore under a versioned key so the policy service can pull
    the matching snapshot later.

    Args:
        version: Policy version number; used as the unique suffix of
            the storage key so multiple snapshots can coexist.

    Raises:
        AssertionError: If no store was provided to this trainer.
    """
    start_time = time.time()
    assert self.store is not None, "Store must be provided to save weights"
    key = f"{self.state_dict_key}{DELIM}{version}"  # Use version as unique id
    await push_state_dict(self.store, self.model.state_dict(), key)
    end_time = time.time()
    # Weight pushes can be slow for large models; log the duration for tuning.
    self.logger.debug(
        f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
    )


@dataclass
Expand All @@ -226,6 +235,9 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
total_reward = 0.0
for reward_fn in self.reward_functions:
reward = reward_fn(prompt, response, target)
self.logger.info(
f"Response: {response} | Target: {target} | Reward: {reward}"
)
total_reward += reward
return total_reward

Expand All @@ -239,15 +251,8 @@ async def compute(self, group: Group) -> list[float]:
rewards = torch.Tensor([[e.reward for e in group.episodes]])
mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)

# if std is nan, return 0s. Remove this before shipping
if std.isnan().any():
advantages = torch.zeros_like(rewards)
else:
advantages = (rewards - mean) / (std + 1e-4)

x = advantages.squeeze(0).tolist()
return x
advantages = (rewards - mean) / (std + 1e-4)
return advantages.squeeze(0).tolist()


class RefModel(ForgeActor):
Expand Down Expand Up @@ -328,10 +333,10 @@ async def pad_token(self):

async def main():
"""Main GRPO training loop with rollout and training processes."""
group_size = 1
model = "Qwen/Qwen3-1.7B-Base"
group_size = 5
model = "Qwen/Qwen3-4B-Base"
max_req_tokens = 512
max_res_tokens = 128
max_res_tokens = 512

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
Expand All @@ -340,6 +345,8 @@ async def main():
project="grpo-training",
)

store = await MultiProcessStore.create_store()

# ---- Setup services ---- #
(
dataloader,
Expand Down Expand Up @@ -368,18 +375,20 @@ async def main():
n=group_size, max_tokens=max_res_tokens
),
),
store=store,
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Trainer,
learning_rate=1e-5,
model_name=model,
store=store,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
batch_size=4,
max_policy_age=1,
batch_size=8,
max_policy_age=0, # Fully on-policy for now
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
Expand Down Expand Up @@ -409,7 +418,13 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["request"], sample["target"]
version = 0 # await policy.get_current_version.choose()
responses = await policy.generate.choose(prompt)
# If weights are updating mid-rollout, response will be cancelled and service
# will return None. We currently throw away the sample.
if responses is None:
continue

version = await policy.get_version.choose()
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
Expand All @@ -421,12 +436,10 @@ async def continuous_rollouts():
target=target,
)

responses = await policy.generate.choose(prompt)

# TODO: Parallelize the following calculation
for episode, response in zip(group.episodes, responses.outputs):
episode.request_tokens = responses.prompt_token_ids
episode.response_tokens = response.token_ids
assert len(response.token_ids) <= max_res_tokens
episode.ref_logprobs = await ref_model.forward.choose(episode)
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
Expand All @@ -436,30 +449,33 @@ async def continuous_rollouts():
episode.advantage = advantage
await replay_buffer.add.choose(episode)

avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
logger.log("avg_response_len/rollout", avg_response_len, rollout_count)
buffer_size = await replay_buffer._numel.choose()
logger.log("buffer_size/rollout", buffer_size, rollout_count)
avg_reward = sum(e.reward for e in group.episodes) / group_size
logger.log("avg_reward/rollout", avg_reward, rollout_count)

rollout_count += 1
if rollout_count % 10 == 0:
avg_reward = sum(e.reward for e in group.episodes) / len(group.episodes)
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
)
logger.log("reward/rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
policy_version = 0
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=0)
batch = await replay_buffer.sample.choose(
curr_policy_version=policy_version
)
if batch is None:
await asyncio.sleep(0.1)
else:
training_result = await trainer.train_step.choose(batch)
loss = await trainer.train_step.choose(batch)
training_step += 1
if training_step % 10 == 0:
print(f"Completed {training_step} training steps")
if training_result:
loss_value = training_result.get("loss", 0.0)
print(f"Latest loss: {loss_value}")
logger.log("loss/training_step", loss_value, training_step)
# await trainer.update_weights(policy)
logger.log("loss/training_step", loss, training_step)
await trainer.push_weights.choose(policy_version)
policy_version += 1
await policy.update_weights.choose()

print("Starting GRPO training loops...")
# TODO: Start multiple rollouts once all serivces support it
Expand Down
44 changes: 38 additions & 6 deletions scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ set -euo pipefail
# Colors for output
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[0;33m'
NC='\033[0m'

log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
log_warning() { echo -e "${YELLOW}[WARNING]${NC} $1";}

# Configuration
PYTORCH_VERSION="2.9.0.dev20250828"
Expand All @@ -34,20 +36,49 @@ check_conda_env() {
log_info "Installing in conda environment: $CONDA_DEFAULT_ENV"
}

# Check sudo access and if it is not available; continue with Conda
check_sudo() {
    # Non-interactive probe: `sudo -n` fails instead of prompting, so this
    # succeeds only when passwordless sudo is configured.
    if ! sudo -n true 2>/dev/null; then
        log_warning "Passwordless sudo access is not available."
        log_info "The script will continue and attempt to install packages via conda instead."
    else
        log_info "Passwordless sudo access detected."
    fi
}

# Install required system packages
install_system_packages() {
    log_info "Installing required system packages..."
    # Check for sudo access (non-interactive probe; see check_sudo)
    if sudo -n true 2>/dev/null; then
        # Detect OS and install packages accordingly
        if [ -f /etc/fedora-release ] || [ -f /etc/centos-release ]; then
            log_info "Detected Fedora OS"
            sudo dnf install -y libibverbs rdma-core libmlx5 libibverbs-devel rdma-core-devel
        elif [ -f /etc/lsb-release ] || [ -f /etc/ubuntu-release ]; then
            log_info "Detected Ubuntu OS"
            sudo apt-get update
            # NOTE(review): verify package names on target Ubuntu releases —
            # the mlx5 provider is typically shipped in `ibverbs-providers`,
            # and `rdma-core-dev` may not exist as a separate package.
            sudo apt-get install -y libibverbs1 rdma-core libmlx5-1 libibverbs-dev rdma-core-dev
        else
            log_error "Unsupported OS for automatic system package installation"
            exit 1
        fi
        log_info "System packages installed successfully"
    else
        # Best-effort fallback when sudo is unavailable: user-space RDMA
        # libraries from conda-forge.
        log_warning "No sudo access detected. Attempting to install packages via conda."
        conda install -c conda-forge rdma-core libibverbs-cos7-x86_64 -y
        log_info "Conda package installation attempted. Please ensure the packages are installed correctly."
    fi
}

# Check to see if gh is installed, if not, it will be installed via conda-forge channel
check_gh_install() {
    if command -v gh &> /dev/null; then
        log_info "GitHub CLI (gh) already installed."
    else
        log_warning "GitHub CLI (gh) not found. Installing via Conda..."
        conda install gh --channel conda-forge -y
    fi
}

# Check wheels exist
Expand Down Expand Up @@ -126,6 +157,7 @@ main() {
conda install -y openssl

install_system_packages
check_gh_install
download_vllm_wheel

log_info "Installing PyTorch nightly..."
Expand Down
Loading
Loading