148 changes: 110 additions & 38 deletions apps/grpo/main.py
@@ -6,21 +6,42 @@

import asyncio
import logging
import tempfile
import time
from dataclasses import dataclass
from typing import Callable

import safetensors.torch as safetensors
import torch

from absl import app, flags
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.data.weights_handle import WeightsHandle, WeightsHandleType
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint

from transformers import AutoModelForCausalLM, AutoTokenizer

FLAGS = flags.FLAGS

flags.DEFINE_integer("batch_size", 8, "")
flags.DEFINE_integer("update_period", 5, "")
flags.DEFINE_integer("group_size", 16, "")
flags.DEFINE_integer("max_policy_age", 10, "")


def clean_up_temp_dir(path: str) -> None:
    """Remove a temporary directory; if given a file path, remove its parent directory."""
    import os
    import shutil

    if os.path.isfile(path):
        path = os.path.dirname(path)
    shutil.rmtree(path, ignore_errors=True)


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
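
Assuming the script is launched directly (the module ends with app.run(main) below), the flags above can be overridden on the command line; an invocation overriding the defaults would look roughly like:

    python apps/grpo/main.py --batch_size=8 --update_period=5 --group_size=16 --max_policy_age=10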

@@ -77,6 +98,9 @@ def __init__(self, episode_id: int, prompt: str, target: str, policy_version: in
def add_group(self, group: Group):
self.groups.append(group)

def add_groups(self, groups: list[Group]):
self.groups.extend(groups)


class Trainer(ForgeActor):
"""GRPO Trainer implementation for policy optimization."""
@@ -189,31 +213,57 @@ async def train_step(self, batch: list[Episode]):
return {"loss": avg_loss, "groups_processed": num_groups_processed}

@endpoint
async def update_weights(self, policy_actor):
"""Update policy model weights with trainer's current weights."""
async def export_weights(self, step: int) -> WeightsHandle:
"""Export weights to a temp file and return the handle."""
# Time how long it takes to update weights
start_time = time.time()

# Set model to eval mode for weight extraction
self.model.eval()

# Extract current model state dict
model_state_dict = self.model.state_dict()
# Save weights to a memory-backed temporary directory using /dev/shm
import os

# Convert tensors to CPU for transfer (if they're on GPU)
cpu_state_dict = {}
for key, tensor in model_state_dict.items():
cpu_state_dict[key] = tensor.cpu() if tensor.is_cuda else tensor
shm_path = "/dev/shm"
if os.path.exists(shm_path):
# Use shared memory filesystem for memory-backed storage
temp_dir = tempfile.mkdtemp(
prefix=f"model_weights_step_{step:08d}_", dir=shm_path
)
else:
# Fallback to system temp if /dev/shm not available
temp_dir = tempfile.mkdtemp(prefix=f"model_weights_step_{step:08d}_")

# Update the policy actor's model weights
await policy_actor.update_model_weights.choose(cpu_state_dict)
try:
# Save weights directly to SafeTensors file
weights_file = os.path.join(temp_dir, "model_weights.safetensors")
state_dict = {name: param for name, param in self.model.named_parameters()}
safetensors.save_file(state_dict, weights_file)

# Create weights handle with the SafeTensors file path
param_names = list(state_dict.keys())
weights_handle = WeightsHandle(
handle_type=WeightsHandleType.FILE,
version=step,
payload={
"param_names": param_names,
"model_path": weights_file,
"model_name": self.model_name,
},
)

except Exception as e:
# Clean up temporary directory if something goes wrong
clean_up_temp_dir(temp_dir)
raise e

# Set model back to training mode
self.model.train()

# Log the time taken
end_time = time.time()
self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
return weights_handle
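
The consuming side of this handle is not part of this diff. As a rough sketch of what a FILE-type handle consumer could do with the payload written above (the helper below is hypothetical, not the actual Policy.update_weights implementation):

    import safetensors.torch as safetensors
    import torch

    from forge.data.weights_handle import WeightsHandle, WeightsHandleType


    def load_weights_from_handle(model: torch.nn.Module, handle: WeightsHandle) -> None:
        """Hypothetical consumer: load exported SafeTensors weights into a live model."""
        assert handle.handle_type == WeightsHandleType.FILE  # only FILE handles carry a file path
        state_dict = safetensors.load_file(handle.payload["model_path"])
        # strict=False: export_weights saved named parameters only, so buffers are not in the file
        model.load_state_dict(state_dict, strict=False)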


class RewardActor(ForgeActor):
@@ -342,9 +392,9 @@ async def __next__(self) -> dict[str, str] | None:
return None


async def main():
async def _main():
"""Main GRPO training loop with rollout and training processes."""
group_size = 1
group_size = 16
model = "Qwen/Qwen3-1.7B"

# ---- Setup WandB Logger ---- #
@@ -412,7 +462,6 @@ async def main():
reward_functions=[MathReward(), ThinkingReward()],
),
)

print("All services initialized successfully!")

# ---- Core RL loops ---- #
@@ -424,61 +473,80 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return
prompt, target = sample["question"], sample["answer"]
version = 0 # await policy.get_current_version.choose()

response = await policy.generate.choose(prompt)
actions = response.completions
version = response.policy_version
episode = Episode(
episode_id=rollout_count,
prompt=prompt,
target=target,
policy_version=version,
)
actions = await policy.generate.choose(prompt)
for action in actions:
ref_logprobs = await ref_model.forward.choose(action.token_ids)
reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=action.text, target=target

async def _get_group(action):
ref_logprobs, reward = await asyncio.gather(
ref_model.forward.choose(action.token_ids),
reward_actor.evaluate_response.choose(
prompt=prompt, response=action.text, target=target
),
)
episode.add_group(
Group(
response=action.text,
ref_logprobs=ref_logprobs,
reward=reward,
)
return Group(
response=action.text, ref_logprobs=ref_logprobs, reward=reward
)

groups = await asyncio.gather(*[_get_group(action) for action in actions])

episode.add_groups(groups)

advantages = await compute_advantages.__call__.choose(episode.groups)
for advantage, group in zip(advantages, episode.groups):
group.advantage = advantage

await replay_buffer.add.choose(episode)

rollout_count += 1
rewards = []
if rollout_count % 10 == 0:
avg_reward = sum(group.reward for group in episode.groups) / len(
episode.groups
)
episode_avg_reward = sum(
group.reward for group in episode.groups
) / len(episode.groups)
rewards.append(episode_avg_reward)
avg_reward = sum(rewards) / len(rewards)
rewards.clear()
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
f"Generated {rollout_count} rollouts, average reward of last 10 = {avg_reward}"
)
logger.log("reward/rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
update_period = FLAGS.update_period
# using training_step as the policy version for now, open to suggestions
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=0)
batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
if batch is None:
await asyncio.sleep(0.1)
else:
# why is call_one not defined?
# training_result = await trainer.train_step.call_one(batch)
training_result = await trainer.train_step.choose(batch)
training_step += 1
if training_step % 10 == 0:
print(f"Completed {training_step} training steps")
if training_result:
loss_value = training_result.get("loss", 0.0)
print(f"Latest loss: {loss_value}")
logger.log("loss/training_step", loss_value, training_step)
# await trainer.update_weights(policy)
print(f"Completed {training_step} training steps")
if training_result:
loss_value = training_result.get("loss", 0.0)
print(f"Latest loss: {loss_value}")
logger.log("loss/training_step", loss_value, training_step)
if training_step % update_period == 0:
print(f"Exporting policy weights @ {training_step=}")
weights_handle = await trainer.export_weights.choose(training_step)
print(f"Exported weights @ {training_step=}")
await policy.update_weights.call(weights_handle)
print(f"Updated policy weights to version @ {training_step=}")
clean_up_temp_dir(weights_handle.payload["model_path"])

print("Starting GRPO training loops...")

rollout_task = asyncio.create_task(continuous_rollouts())
training_task = asyncio.create_task(continuous_training())

Expand All @@ -501,5 +569,9 @@ async def continuous_training():
)


def main(argv):
asyncio.run(_main())


if __name__ == "__main__":
asyncio.run(main())
app.run(main)
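
WeightsHandle and WeightsHandleType are imported from forge.data.weights_handle rather than defined in this diff. From the way the handle is built in export_weights and consumed in continuous_training, it presumably looks roughly like the following; the exact dataclass layout is an assumption:

    from dataclasses import dataclass, field
    from enum import Enum, auto
    from typing import Any


    class WeightsHandleType(Enum):
        FILE = auto()  # payload carries a path to a SafeTensors file on a shared filesystem


    @dataclass
    class WeightsHandle:
        handle_type: WeightsHandleType
        version: int  # trainer step at which the weights were exported
        payload: dict[str, Any] = field(default_factory=dict)  # param_names, model_path, model_name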
12 changes: 10 additions & 2 deletions apps/vllm/main.py
@@ -15,7 +15,13 @@
from argparse import Namespace
from typing import List

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import (
CompletionPolicyResponse,
Policy,
PolicyConfig,
SamplingOverrides,
WorkerConfig,
)
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import CompletionOutput

@@ -89,7 +95,9 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:

async with policy.session():
print("Requesting generation...")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
responses: List[CompletionOutput] = (
    await policy.generate.choose(prompt=prompt)
).completions

print("\nGeneration Results:")
print("=" * 80)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
# Miscellaneous
"omegaconf",
"wandb",
"aiorwlock",
]
dynamic = ["version"]
