diff --git a/.meta/mast/main.py b/.meta/mast/main.py
index 513d96fc6..b764a7630 100644
--- a/.meta/mast/main.py
+++ b/.meta/mast/main.py
@@ -15,7 +15,7 @@
MastLauncher,
mount_mnt_directory,
)
-from forge.controller.provisioner import init_provisioner
+from forge.controller.provisioner import get_or_create_provisioner
from forge.types import (
Launcher,
@@ -68,7 +68,9 @@ async def main(cfg: DictConfig, mode: str = "detached", extra_args: list = None)
else:
# In remote mode, we're already running inside MAST, so mount directory, init provisioner and run training
mount_mnt_directory("/mnt/wsfuse")
- await init_provisioner(ProvisionerConfig(launcher_config=launcher_config))
+ await get_or_create_provisioner(
+ ProvisionerConfig(launcher_config=launcher_config)
+ )
await grpo_main(cfg)
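
Note on the rename: `get_or_create_provisioner` signals idempotent semantics, where the first call creates the global provisioner and later calls return it unchanged. A minimal sketch of that contract with plain stand-ins (the real version spawns a Monarch controller actor, see the provisioner.py changes below):

```python
# Sketch only: stand-in objects, not forge/Monarch types. The key property is
# that later calls, even with a different cfg, return the first-created instance.
import asyncio

_instance: dict | None = None

async def get_or_create(cfg: dict | None = None) -> dict:
    global _instance
    if _instance is None:
        _instance = {"cfg": cfg}  # stand-in for the spawned Provisioner
    return _instance

async def main() -> None:
    a = await get_or_create({"launcher": "mast"})
    b = await get_or_create()  # returns the same instance; cfg is ignored here
    assert a is b

asyncio.run(main())
```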
diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 85872681f..bbd64f415 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -25,7 +25,7 @@
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.controller.actor import ForgeActor
-from forge.controller.provisioner import init_provisioner, shutdown
+from forge.controller.provisioner import get_or_create_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
@@ -298,11 +298,11 @@ async def main(cfg: DictConfig):
# ---- Global setups ---- #
provisioner = None
if cfg.get("provisioner", None) is not None:
- provisioner = await init_provisioner(
+ provisioner = await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
else:
- provisioner = await init_provisioner()
+ provisioner = await get_or_create_provisioner()
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger(process_name="Controller")
@@ -346,7 +346,7 @@ async def main(cfg: DictConfig):
# TODO: support multiple host meshes
trainer_num_procs = cfg.actors.trainer["procs"]
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
- trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+    trainer_hosts = await provisioner.get_host_mesh.call_one(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
strategy=ts.LocalRankStrategy(),
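
Since `Provisioner` is now an actor (see provisioner.py below), `get_host_mesh` is an endpoint and the call site receives a future; the shape of the adjusted call, shown for contrast (context names as in the hunk above, not a standalone snippet):

```python
# Before: Provisioner was a local object, so this was a plain method call.
# trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)

# After: the endpoint goes through call_one() and must be awaited before
# the result can be used to spawn procs.
trainer_hosts = await provisioner.get_host_mesh.call_one(trainer_host_mesh_name)
```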
diff --git a/apps/grpo/notebook.ipynb b/apps/grpo/notebook.ipynb
index 8d9fbc75a..ceebb997b 100644
--- a/apps/grpo/notebook.ipynb
+++ b/apps/grpo/notebook.ipynb
@@ -1,669 +1,672 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "46c66f45-4be3-4674-a870-3849c1048ddb",
- "metadata": {},
- "source": [
- "# GRPO for Math (GSM8k)\n",
- "\n",
- "## Import modules"
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "46c66f45-4be3-4674-a870-3849c1048ddb",
+ "metadata": {},
+ "source": [
+ "# GRPO for Math (GSM8k)\n",
+ "\n",
+ "## Import modules"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
+ "# All rights reserved.\n",
+ "#\n",
+ "# This source code is licensed under the BSD-style license found in the\n",
+ "# LICENSE file in the root directory of this source tree.\n",
+ "\n",
+ "import asyncio\n",
+ "import time\n",
+ "import uuid\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Any, Callable\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import torchstore as ts\n",
+ "from datasets import load_dataset\n",
+ "from forge.actors._torchstore_utils import (\n",
+ " get_dcp_whole_state_dict_key,\n",
+ " get_param_prefix,\n",
+ ")\n",
+ "from forge.actors.generator import Generator as Policy\n",
+ "from forge.actors.reference_model import ReferenceModel\n",
+ "from forge.actors.replay_buffer import ReplayBuffer\n",
+ "from forge.actors.trainer import RLTrainer\n",
+ "from forge.cli.config import parse\n",
+ "from forge.controller.actor import ForgeActor\n",
+ "from forge.controller.provisioner import get_or_create_provisioner, shutdown\n",
+ "from forge.data.rewards import MathReward, ThinkingReward\n",
+ "from forge.observability.metric_actors import get_or_create_metric_logger\n",
+ "from forge.observability.metrics import record_metric, Reduce\n",
+ "from forge.observability.perf_tracker import Tracer\n",
+ "\n",
+ "from forge.types import LauncherConfig, ProvisionerConfig\n",
+ "from forge.util.ops import compute_logprobs\n",
+ "from monarch.actor import endpoint\n",
+ "from omegaconf import DictConfig\n",
+ "from vllm.transformers_utils.tokenizer import get_tokenizer\n",
+ "\n",
+ "import os\n",
+ "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n",
+ "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2",
+ "metadata": {},
+ "source": [
+ "## Define Data Structures"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class Episode:\n",
+    "    # TODO: add additional layer for multi-turn\n",
+ " episode_id: str\n",
+ " request: str\n",
+ " policy_version: int\n",
+ " pad_id: int\n",
+ " request_len: int\n",
+ " response_len: int\n",
+ " target: Any | None = None\n",
+ " # processed data\n",
+ " response: str | None = None\n",
+ " request_tokens: list[int] | None = None\n",
+ " response_tokens: list[int] | None = None\n",
+ " ref_logprobs: torch.Tensor | None = None\n",
+ " reward: float | None = None\n",
+ " advantage: float | None = None\n",
+ "\n",
+ " @property\n",
+ " def request_tensor(self):\n",
+ " tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n",
+ " if tensor.shape[0] < self.request_len: # left pad\n",
+ " diff = self.request_len - tensor.shape[0]\n",
+ " tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n",
+ " return tensor\n",
+ "\n",
+ " @property\n",
+ " def response_tensor(self):\n",
+ " tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n",
+ " if tensor.shape[0] < self.response_len: # right pad\n",
+ " diff = self.response_len - tensor.shape[0]\n",
+ " tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n",
+ " return tensor\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class Group:\n",
+ " group_id: str\n",
+ " episodes: list[Episode]\n",
+ "\n",
+ " @classmethod\n",
+ " def new_group(\n",
+ " cls,\n",
+ " group_id: int,\n",
+ " group_size: int,\n",
+ " request: str,\n",
+ " policy_version: int,\n",
+ " pad_id: int,\n",
+ " request_len: int,\n",
+ " response_len: int,\n",
+ " target: Any = None,\n",
+ " ):\n",
+ " episodes = []\n",
+ " for _ in range(group_size):\n",
+ " episodes.append(\n",
+ " Episode(\n",
+ " episode_id=str(uuid.uuid4()),\n",
+ " request=request,\n",
+ " policy_version=policy_version,\n",
+ " pad_id=pad_id,\n",
+ " request_len=request_len,\n",
+ " response_len=response_len,\n",
+ " target=target,\n",
+ " )\n",
+ " )\n",
+ " return cls(str(group_id), episodes)\n",
+ "\n",
+ "\n",
+ "def collate(batches: list[list[Episode]]):\n",
+ " inputs = []\n",
+ " targets = []\n",
+ " for batch in batches:\n",
+ " request = [e.request_tensor for e in batch]\n",
+ " request = torch.stack(request) # [b x s]\n",
+ "\n",
+ " response = [e.response_tensor for e in batch]\n",
+ " response = torch.stack(response) # [b x s]\n",
+ "\n",
+ " ref_logprobs = [e.ref_logprobs for e in batch]\n",
+ " ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]\n",
+ "\n",
+ " advantages = [e.advantage for e in batch]\n",
+ " advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]\n",
+ "\n",
+ " pad_id = batch[0].pad_id\n",
+ " mask = response != pad_id\n",
+ "\n",
+ " input = {\"tokens\": torch.cat([request, response], dim=1)}\n",
+ " target = {\n",
+ " \"response\": response,\n",
+ " \"ref_logprobs\": ref_logprobs,\n",
+ " \"advantages\": advantages,\n",
+ " \"padding_mask\": mask,\n",
+ " }\n",
+ " inputs.append(input)\n",
+ " targets.append(target)\n",
+ " return inputs, targets\n",
+ "\n",
+ "@dataclass\n",
+ "class DatasetActor(ForgeActor):\n",
+ " \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n",
+ "\n",
+ " path: str = \"openai/gsm8k\"\n",
+ " revision: str = \"main\"\n",
+ " data_split: str = \"train\"\n",
+ " streaming: bool = True\n",
+ " model: str = \"Qwen/Qwen3-1.7B\"\n",
+ "\n",
+ " @endpoint\n",
+ " def setup(self):\n",
+ " self._tokenizer = get_tokenizer(self.model)\n",
+ "\n",
+ " def gsm8k_transform(sample):\n",
+ " system_prompt = \"\"\"\n",
+    "            Put all your scratchpad work between <think> and </think> tags.\n",
+    "            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.\n",
+ " \"\"\"\n",
+ " request: str = sample[\"question\"]\n",
+ " as_chat = [\n",
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
+ " {\"role\": \"user\", \"content\": request},\n",
+ " ]\n",
+ " formatted_request = self._tokenizer.apply_chat_template(\n",
+ " as_chat,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=True,\n",
+ " )\n",
+ " target: str = sample[\"answer\"]\n",
+ " formatted_target = target.split(\"#### \")[1]\n",
+ " return {\"request\": formatted_request, \"target\": formatted_target}\n",
+ "\n",
+ " ds = load_dataset(\n",
+ " self.path, self.revision, split=self.data_split, streaming=self.streaming\n",
+ " )\n",
+ " ds = ds.map(gsm8k_transform)\n",
+ " ds = ds.shuffle()\n",
+ " self._iterator = iter(ds)\n",
+ "\n",
+ " @endpoint\n",
+ " async def sample(self) -> dict[str, str] | None:\n",
+ " try:\n",
+ " sample = next(self._iterator)\n",
+ "\n",
+ " # Record dataset metrics\n",
+ " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n",
+ " record_metric(\n",
+ " \"dataset/sample/avg_sample_len\",\n",
+ " len(sample[\"request\"]),\n",
+ " Reduce.MEAN,\n",
+ " )\n",
+ "\n",
+ " return sample\n",
+ " except StopIteration:\n",
+ " return None\n",
+ "\n",
+ " @endpoint\n",
+ " async def pad_token(self):\n",
+ " return self._tokenizer.pad_token_id"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef",
+ "metadata": {},
+ "source": [
+ "## Define loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "934aca32-0953-4945-9f99-e7b34804443b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def simple_grpo_loss(\n",
+ " logits: torch.Tensor,\n",
+ " response: torch.Tensor,\n",
+ " ref_logprobs: torch.Tensor,\n",
+ " advantages: torch.Tensor,\n",
+ " padding_mask: torch.Tensor,\n",
+ " beta: float = 0.1,\n",
+ ") -> torch.Tensor:\n",
+ " \"\"\"\n",
+ " Example GRPO Loss Function for RLTrainer\n",
+ " \"\"\"\n",
+ " logprobs: torch.Tensor = compute_logprobs(logits, response)\n",
+ "\n",
+ " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n",
+ " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n",
+ " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n",
+ " per_token_loss = -(per_token_policy_loss - beta * kl)\n",
+ " loss = (\n",
+ " ((per_token_loss * padding_mask).sum(dim=1))\n",
+ " / (padding_mask.sum(dim=1).clamp(min=1.0))\n",
+ " ).mean()\n",
+ " return loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104",
+ "metadata": {},
+ "source": [
+ "## Define Reward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class RewardActor(ForgeActor):\n",
+ " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n",
+ "\n",
+ " reward_functions: list[Callable]\n",
+ "\n",
+ " @endpoint\n",
+ " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n",
+ " total_rewards = 0.0\n",
+ " for reward_fn in self.reward_functions:\n",
+ " reward = reward_fn(prompt, response, target)\n",
+ " total_rewards += reward\n",
+ "\n",
+ " # Get a name for the reward function (works for classes, functions, lambdas)\n",
+ " reward_fn_name = getattr(\n",
+ " reward_fn, \"__name__\", reward_fn.__class__.__name__\n",
+ " )\n",
+ " # per function reward\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n",
+ " reward,\n",
+ " Reduce.SUM,\n",
+ " )\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n",
+ " reward,\n",
+ " Reduce.MEAN,\n",
+ " )\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n",
+ " reward,\n",
+ " Reduce.STD,\n",
+ " )\n",
+ "\n",
+ " # avg total reward\n",
+ " record_metric(\n",
+ " \"reward/evaluate_response/avg_total_reward\",\n",
+ " reward,\n",
+ " Reduce.MEAN,\n",
+ " )\n",
+ "\n",
+ " # count fn calls\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n",
+ " 1,\n",
+ " Reduce.SUM,\n",
+ " )\n",
+ "\n",
+ " avg_reward = total_rewards / len(self.reward_functions)\n",
+ " return avg_reward\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class ComputeAdvantages(ForgeActor):\n",
+ " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n",
+ "\n",
+ " @endpoint\n",
+ " async def compute(self, group: Group) -> list[float]:\n",
+ " # TODO: add batch processing\n",
+ " rewards = torch.tensor([[e.reward for e in group.episodes]])\n",
+ " mean = rewards.mean(1, keepdim=True)\n",
+ " std = rewards.std(1, keepdim=True)\n",
+ " advantages = (rewards - mean) / (std + 1e-4)\n",
+ " return advantages.squeeze(0).tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "88523484-b414-41db-bd3f-0d8dbf881a85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def drop_weights(version: int):\n",
+ " print(f\"Dropping weights @ version {version}\")\n",
+ " start_time = time.perf_counter()\n",
+ " prefix = get_param_prefix(version)\n",
+ " matching_keys = await ts.keys(prefix)\n",
+ " # TODO: once we have something like `get_meta()` in torchstore, we can just\n",
+ " # query the type of the object instead of relying on keys.\n",
+ " dcp_key = get_dcp_whole_state_dict_key(version)\n",
+ " if dcp_key in matching_keys:\n",
+ " dcp_handle = await ts.get(dcp_key)\n",
+ " dcp_handle.drop()\n",
+ " for key in matching_keys:\n",
+ " await ts.delete(key)\n",
+ " elapsed = time.perf_counter() - start_time\n",
+ " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72",
+ "metadata": {},
+ "source": [
+ "## Setup Services"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6c811974-cd6b-40ed-a179-4511a7a6c489",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "from omegaconf import OmegaConf\n",
+ "from forge.cli.config import resolve_hf_hub_paths\n",
+ "\n",
+ "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n",
+ "cfg = resolve_hf_hub_paths(cfg)\n",
+ "OmegaConf.resolve(cfg)\n",
+ "\n",
+ "group_size = cfg.group_size # 8\n",
+ "max_req_tokens = cfg.max_req_tokens # 512\n",
+ "max_res_tokens = cfg.max_res_tokens # 512\n",
+ "\n",
+ "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n",
+ "mlogger = await get_or_create_metric_logger()\n",
+ "await mlogger.init_backends.call_one(metric_logging_cfg)\n",
+ "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n",
+ "\n",
+ "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n",
+ " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n",
+ " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n",
+ " RLTrainer.options(**cfg.actors.trainer).as_actor(\n",
+ " **cfg.trainer, loss=simple_grpo_loss\n",
+ " ),\n",
+ " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n",
+ " **cfg.replay_buffer, collate=collate\n",
+ " ),\n",
+ " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n",
+ " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n",
+ " RewardActor.options(**cfg.services.reward_actor).as_service(\n",
+ " reward_functions=[MathReward(), ThinkingReward()]\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7",
+ "metadata": {},
+ "source": [
+ "## Rollout Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def continuous_rollouts():\n",
+ " rollout_count = 0\n",
+ " pad_id = await dataloader.pad_token.call_one()\n",
+ " while True:\n",
+ " t = Tracer(\"main_perf/continuous_rollouts\")\n",
+ " t.start()\n",
+ " sample = await dataloader.sample.call_one()\n",
+ " if sample is None:\n",
+ " print(\"Dataloader is empty, exiting continuous rollout\")\n",
+ " return\n",
+ "\n",
+ " t.step(\"data_loading\")\n",
+ "\n",
+ " prompt, target = sample[\"request\"], sample[\"target\"]\n",
+ " responses = await policy.generate.route(prompt)\n",
+    "            # TODO: this should be part of the response metadata instead of a separate call\n",
+ " version = await policy.get_version.route()\n",
+ "\n",
+ " t.step(\"policy_generation\")\n",
+ "\n",
+ " assert (\n",
+ " len(responses) > 0\n",
+ " ), \"Sanity check: Responses should NEVER return empty\"\n",
+ " assert (\n",
+ " version := responses[0].generator_version\n",
+ " ) is not None, \"Response must indicate a version\"\n",
+ " group = Group.new_group(\n",
+ " group_id=rollout_count,\n",
+ " group_size=group_size,\n",
+ " request=prompt,\n",
+ " policy_version=version,\n",
+ " pad_id=pad_id,\n",
+ " request_len=max_req_tokens,\n",
+ " response_len=max_res_tokens,\n",
+ " target=target,\n",
+ " )\n",
+ "\n",
+ " input_ids = torch.ones(\n",
+ " (group_size, max_req_tokens + max_res_tokens),\n",
+ " dtype=torch.long,\n",
+ " device=\"cuda\",\n",
+ " )\n",
+ " # Populate episode info and calculate rewards\n",
+ " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n",
+ " episode.request_tokens = response.prompt_ids\n",
+ " episode.response_tokens = response.token_ids\n",
+ " episode.response = response.text\n",
+ " input_ids[i, :max_req_tokens] = episode.request_tensor\n",
+ " input_ids[i, max_req_tokens:] = episode.response_tensor\n",
+ " episode.reward = await reward_actor.evaluate_response.route(\n",
+ " prompt=prompt, response=response.text, target=target\n",
+ " )\n",
+ "\n",
+ " t.step(\"reward_evaluation\")\n",
+ "\n",
+ " ref_logprobs = await ref_model.forward.route(\n",
+ " input_ids, max_req_tokens, return_logprobs=True\n",
+ " )\n",
+ " t.step(\"reference_model_calculate_logprobs\")\n",
+ "\n",
+ " for i, episode in enumerate(group.episodes):\n",
+ " episode.ref_logprobs = ref_logprobs[i]\n",
+ " del ref_logprobs, input_ids\n",
+ " t.step(\"compute_logprobs\")\n",
+ "\n",
+ " # Calculate advantages and add to replay buffer\n",
+ " advantages = await compute_advantages.compute.call_one(group)\n",
+ " for episode, advantage in zip(group.episodes, advantages):\n",
+ " episode.advantage = advantage\n",
+ " await replay_buffer.add.call_one(episode)\n",
+ "\n",
+ " # Log metrics\n",
+ " rollout_count += 1\n",
+ " record_metric(\n",
+ " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n",
+ " )\n",
+ " t.stop()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b",
+ "metadata": {},
+ "source": [
+ "## Training Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def continuous_training():\n",
+ " training_step = 0\n",
+ " restart_tracer = True # Flag to control when to restart tracer\n",
+ " while True:\n",
+ " # Restart tracer when needed (initial start or after completing a training step)\n",
+ " # Otherwise, we cannot measure time waiting for buffer\n",
+ " if restart_tracer:\n",
+ " t = Tracer(\"main_perf/continuous_training\")\n",
+ " t.start()\n",
+ " restart_tracer = False\n",
+ "\n",
+ " batch = await replay_buffer.sample.call_one(\n",
+ " curr_policy_version=training_step\n",
+ " )\n",
+ " if batch is None:\n",
+ " await asyncio.sleep(0.1)\n",
+ " else:\n",
+ " t.step(\"waiting_for_buffer\")\n",
+ "\n",
+ " inputs, targets = batch\n",
+ " await trainer.train_step.call(inputs, targets)\n",
+ " training_step += 1\n",
+ " t.step(\"train_step\")\n",
+ "\n",
+ " await trainer.push_weights.call(training_step)\n",
+ " t.step(\"push_weights\")\n",
+ "\n",
+    "            await policy.update_weights.fanout(training_step)\n",
+ " t.step(\"update_weights\")\n",
+ "\n",
+ " if training_step >= 2:\n",
+ " await drop_weights(training_step - 1)\n",
+ " t.step(\"drop_weights\")\n",
+ "\n",
+ " t.stop()\n",
+ " restart_tracer = True\n",
+ "\n",
+ " # Flush metrics every training step to WandB\n",
+ " await mlogger.flush.call_one(training_step)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4542863b-59c5-40bc-896c-6d8d44ada00f",
+ "metadata": {},
+ "source": [
+ "## Run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "58194c13-b75e-405d-ab11-18cbe1874d92",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "num_rollout_threads = 1\n",
+ "num_training_threads = 1\n",
+ "\n",
+ "rollout_tasks = [\n",
+ " asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n",
+ "]\n",
+ "training_task = asyncio.create_task(continuous_training())\n",
+ "\n",
+ "try:\n",
+ " await asyncio.gather(*rollout_tasks, training_task)\n",
+ "except KeyboardInterrupt:\n",
+ " print(\"Training interrupted by user\")\n",
+ " for rollout_task in rollout_tasks:\n",
+ " rollout_task.cancel()\n",
+ " training_task.cancel()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4603b80-1f25-49a1-920e-d24f38dfc687",
+ "metadata": {},
+ "source": [
+ "## Shutdown"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "await mlogger.shutdown.call_one()\n",
+ "await asyncio.sleep(2)\n",
+ "\n",
+ "await asyncio.gather(\n",
+ " DatasetActor.shutdown(dataloader),\n",
+ " policy.shutdown(),\n",
+ " RLTrainer.shutdown(trainer),\n",
+ " ReplayBuffer.shutdown(replay_buffer),\n",
+ " ComputeAdvantages.shutdown(compute_advantages),\n",
+ " ref_model.shutdown(),\n",
+ " reward_actor.shutdown(),\n",
+ ")\n",
+ "await shutdown()"
+ ]
+ }
+ ],
+ "metadata": {
+ "fileHeader": "",
+ "fileUid": "33e39445-9f70-425a-b663-fe2c861ef07b",
+ "isAdHoc": false,
+ "kernelspec": {
+ "display_name": "forge",
+ "language": "python",
+ "name": "forge"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.18"
+ }
},
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
- "# All rights reserved.\n",
- "#\n",
- "# This source code is licensed under the BSD-style license found in the\n",
- "# LICENSE file in the root directory of this source tree.\n",
- "\n",
- "import asyncio\n",
- "import time\n",
- "import uuid\n",
- "from dataclasses import dataclass\n",
- "from typing import Any, Callable\n",
- "\n",
- "import torch\n",
- "import torch.nn.functional as F\n",
- "import torchstore as ts\n",
- "from datasets import load_dataset\n",
- "from forge.actors._torchstore_utils import (\n",
- " get_dcp_whole_state_dict_key,\n",
- " get_param_prefix,\n",
- ")\n",
- "from forge.actors.generator import Generator as Policy\n",
- "from forge.actors.reference_model import ReferenceModel\n",
- "from forge.actors.replay_buffer import ReplayBuffer\n",
- "from forge.actors.trainer import RLTrainer\n",
- "from forge.cli.config import parse\n",
- "from forge.controller.actor import ForgeActor\n",
- "from forge.controller.provisioner import init_provisioner, shutdown\n",
- "from forge.data.rewards import MathReward, ThinkingReward\n",
- "from forge.observability.metric_actors import get_or_create_metric_logger\n",
- "from forge.observability.metrics import record_metric, Reduce\n",
- "from forge.observability.perf_tracker import Tracer\n",
- "\n",
- "from forge.types import LauncherConfig, ProvisionerConfig\n",
- "from forge.util.ops import compute_logprobs\n",
- "from monarch.actor import endpoint\n",
- "from omegaconf import DictConfig\n",
- "from vllm.transformers_utils.tokenizer import get_tokenizer\n",
- "\n",
- "import os\n",
- "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n",
- "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\""
- ]
- },
- {
- "cell_type": "markdown",
- "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2",
- "metadata": {},
- "source": [
- "## Define Data Structures"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a",
- "metadata": {},
- "outputs": [],
- "source": [
- "@dataclass\n",
- "class Episode:\n",
- " # TODO: add adtional layer for multi-turn\n",
- " episode_id: str\n",
- " request: str\n",
- " policy_version: int\n",
- " pad_id: int\n",
- " request_len: int\n",
- " response_len: int\n",
- " target: Any | None = None\n",
- " # processed data\n",
- " response: str | None = None\n",
- " request_tokens: list[int] | None = None\n",
- " response_tokens: list[int] | None = None\n",
- " ref_logprobs: torch.Tensor | None = None\n",
- " reward: float | None = None\n",
- " advantage: float | None = None\n",
- "\n",
- " @property\n",
- " def request_tensor(self):\n",
- " tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n",
- " if tensor.shape[0] < self.request_len: # left pad\n",
- " diff = self.request_len - tensor.shape[0]\n",
- " tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n",
- " return tensor\n",
- "\n",
- " @property\n",
- " def response_tensor(self):\n",
- " tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n",
- " if tensor.shape[0] < self.response_len: # right pad\n",
- " diff = self.response_len - tensor.shape[0]\n",
- " tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n",
- " return tensor\n",
- "\n",
- "\n",
- "@dataclass\n",
- "class Group:\n",
- " group_id: str\n",
- " episodes: list[Episode]\n",
- "\n",
- " @classmethod\n",
- " def new_group(\n",
- " cls,\n",
- " group_id: int,\n",
- " group_size: int,\n",
- " request: str,\n",
- " policy_version: int,\n",
- " pad_id: int,\n",
- " request_len: int,\n",
- " response_len: int,\n",
- " target: Any = None,\n",
- " ):\n",
- " episodes = []\n",
- " for _ in range(group_size):\n",
- " episodes.append(\n",
- " Episode(\n",
- " episode_id=str(uuid.uuid4()),\n",
- " request=request,\n",
- " policy_version=policy_version,\n",
- " pad_id=pad_id,\n",
- " request_len=request_len,\n",
- " response_len=response_len,\n",
- " target=target,\n",
- " )\n",
- " )\n",
- " return cls(str(group_id), episodes)\n",
- "\n",
- "\n",
- "def collate(batches: list[list[Episode]]):\n",
- " inputs = []\n",
- " targets = []\n",
- " for batch in batches:\n",
- " request = [e.request_tensor for e in batch]\n",
- " request = torch.stack(request) # [b x s]\n",
- "\n",
- " response = [e.response_tensor for e in batch]\n",
- " response = torch.stack(response) # [b x s]\n",
- "\n",
- " ref_logprobs = [e.ref_logprobs for e in batch]\n",
- " ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]\n",
- "\n",
- " advantages = [e.advantage for e in batch]\n",
- " advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]\n",
- "\n",
- " pad_id = batch[0].pad_id\n",
- " mask = response != pad_id\n",
- "\n",
- " input = {\"tokens\": torch.cat([request, response], dim=1)}\n",
- " target = {\n",
- " \"response\": response,\n",
- " \"ref_logprobs\": ref_logprobs,\n",
- " \"advantages\": advantages,\n",
- " \"padding_mask\": mask,\n",
- " }\n",
- " inputs.append(input)\n",
- " targets.append(target)\n",
- " return inputs, targets\n",
- "\n",
- "@dataclass\n",
- "class DatasetActor(ForgeActor):\n",
- " \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n",
- "\n",
- " path: str = \"openai/gsm8k\"\n",
- " revision: str = \"main\"\n",
- " data_split: str = \"train\"\n",
- " streaming: bool = True\n",
- " model: str = \"Qwen/Qwen3-1.7B\"\n",
- "\n",
- " @endpoint\n",
- " def setup(self):\n",
- " self._tokenizer = get_tokenizer(self.model)\n",
- "\n",
- " def gsm8k_transform(sample):\n",
- " system_prompt = \"\"\"\n",
- " Put all your scratchpad work between and tags.\n",
- " Your final answer should be between and tags otherwise it will not be scored.\n",
- " \"\"\"\n",
- " request: str = sample[\"question\"]\n",
- " as_chat = [\n",
- " {\"role\": \"system\", \"content\": system_prompt},\n",
- " {\"role\": \"user\", \"content\": request},\n",
- " ]\n",
- " formatted_request = self._tokenizer.apply_chat_template(\n",
- " as_chat,\n",
- " tokenize=False,\n",
- " add_generation_prompt=True,\n",
- " )\n",
- " target: str = sample[\"answer\"]\n",
- " formatted_target = target.split(\"#### \")[1]\n",
- " return {\"request\": formatted_request, \"target\": formatted_target}\n",
- "\n",
- " ds = load_dataset(\n",
- " self.path, self.revision, split=self.data_split, streaming=self.streaming\n",
- " )\n",
- " ds = ds.map(gsm8k_transform)\n",
- " ds = ds.shuffle()\n",
- " self._iterator = iter(ds)\n",
- "\n",
- " @endpoint\n",
- " async def sample(self) -> dict[str, str] | None:\n",
- " try:\n",
- " sample = next(self._iterator)\n",
- "\n",
- " # Record dataset metrics\n",
- " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n",
- " record_metric(\n",
- " \"dataset/sample/avg_sample_len\",\n",
- " len(sample[\"request\"]),\n",
- " Reduce.MEAN,\n",
- " )\n",
- "\n",
- " return sample\n",
- " except StopIteration:\n",
- " return None\n",
- "\n",
- " @endpoint\n",
- " async def pad_token(self):\n",
- " return self._tokenizer.pad_token_id"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef",
- "metadata": {},
- "source": [
- "## Define loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "934aca32-0953-4945-9f99-e7b34804443b",
- "metadata": {},
- "outputs": [],
- "source": [
- "def simple_grpo_loss(\n",
- " logits: torch.Tensor,\n",
- " response: torch.Tensor,\n",
- " ref_logprobs: torch.Tensor,\n",
- " advantages: torch.Tensor,\n",
- " padding_mask: torch.Tensor,\n",
- " beta: float = 0.1,\n",
- ") -> torch.Tensor:\n",
- " \"\"\"\n",
- " Example GRPO Loss Function for RLTrainer\n",
- " \"\"\"\n",
- " logprobs: torch.Tensor = compute_logprobs(logits, response)\n",
- "\n",
- " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n",
- " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n",
- " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n",
- " per_token_loss = -(per_token_policy_loss - beta * kl)\n",
- " loss = (\n",
- " ((per_token_loss * padding_mask).sum(dim=1))\n",
- " / (padding_mask.sum(dim=1).clamp(min=1.0))\n",
- " ).mean()\n",
- " return loss"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104",
- "metadata": {},
- "source": [
- "## Define Reward"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3",
- "metadata": {},
- "outputs": [],
- "source": [
- "@dataclass\n",
- "class RewardActor(ForgeActor):\n",
- " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n",
- "\n",
- " reward_functions: list[Callable]\n",
- "\n",
- " @endpoint\n",
- " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n",
- " total_rewards = 0.0\n",
- " for reward_fn in self.reward_functions:\n",
- " reward = reward_fn(prompt, response, target)\n",
- " total_rewards += reward\n",
- "\n",
- " # Get a name for the reward function (works for classes, functions, lambdas)\n",
- " reward_fn_name = getattr(\n",
- " reward_fn, \"__name__\", reward_fn.__class__.__name__\n",
- " )\n",
- " # per function reward\n",
- " record_metric(\n",
- " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n",
- " reward,\n",
- " Reduce.SUM,\n",
- " )\n",
- " record_metric(\n",
- " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n",
- " reward,\n",
- " Reduce.MEAN,\n",
- " )\n",
- " record_metric(\n",
- " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n",
- " reward,\n",
- " Reduce.STD,\n",
- " )\n",
- "\n",
- " # avg total reward\n",
- " record_metric(\n",
- " \"reward/evaluate_response/avg_total_reward\",\n",
- " reward,\n",
- " Reduce.MEAN,\n",
- " )\n",
- "\n",
- " # count fn calls\n",
- " record_metric(\n",
- " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n",
- " 1,\n",
- " Reduce.SUM,\n",
- " )\n",
- "\n",
- " avg_reward = total_rewards / len(self.reward_functions)\n",
- " return avg_reward\n",
- "\n",
- "\n",
- "@dataclass\n",
- "class ComputeAdvantages(ForgeActor):\n",
- " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n",
- "\n",
- " @endpoint\n",
- " async def compute(self, group: Group) -> list[float]:\n",
- " # TODO: add batch processing\n",
- " rewards = torch.tensor([[e.reward for e in group.episodes]])\n",
- " mean = rewards.mean(1, keepdim=True)\n",
- " std = rewards.std(1, keepdim=True)\n",
- " advantages = (rewards - mean) / (std + 1e-4)\n",
- " return advantages.squeeze(0).tolist()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "88523484-b414-41db-bd3f-0d8dbf881a85",
- "metadata": {},
- "outputs": [],
- "source": [
- "async def drop_weights(version: int):\n",
- " print(f\"Dropping weights @ version {version}\")\n",
- " start_time = time.perf_counter()\n",
- " prefix = get_param_prefix(version)\n",
- " matching_keys = await ts.keys(prefix)\n",
- " # TODO: once we have something like `get_meta()` in torchstore, we can just\n",
- " # query the type of the object instead of relying on keys.\n",
- " dcp_key = get_dcp_whole_state_dict_key(version)\n",
- " if dcp_key in matching_keys:\n",
- " dcp_handle = await ts.get(dcp_key)\n",
- " dcp_handle.drop()\n",
- " for key in matching_keys:\n",
- " await ts.delete(key)\n",
- " elapsed = time.perf_counter() - start_time\n",
- " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72",
- "metadata": {},
- "source": [
- "## Setup Services"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6c811974-cd6b-40ed-a179-4511a7a6c489",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "from omegaconf import OmegaConf\n",
- "from forge.cli.config import resolve_hf_hub_paths\n",
- "\n",
- "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n",
- "cfg = resolve_hf_hub_paths(cfg)\n",
- "OmegaConf.resolve(cfg)\n",
- "\n",
- "group_size = cfg.group_size # 8\n",
- "max_req_tokens = cfg.max_req_tokens # 512\n",
- "max_res_tokens = cfg.max_res_tokens # 512\n",
- "\n",
- "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n",
- "mlogger = await get_or_create_metric_logger()\n",
- "await mlogger.init_backends.call_one(metric_logging_cfg)\n",
- "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n",
- "\n",
- "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n",
- " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n",
- " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n",
- " RLTrainer.options(**cfg.actors.trainer).as_actor(\n",
- " **cfg.trainer, loss=simple_grpo_loss\n",
- " ),\n",
- " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n",
- " **cfg.replay_buffer, collate=collate\n",
- " ),\n",
- " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n",
- " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n",
- " RewardActor.options(**cfg.services.reward_actor).as_service(\n",
- " reward_functions=[MathReward(), ThinkingReward()]\n",
- " ),\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7",
- "metadata": {},
- "source": [
- "## Rollout Loop"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af",
- "metadata": {},
- "outputs": [],
- "source": [
- "async def continuous_rollouts():\n",
- " rollout_count = 0\n",
- " pad_id = await dataloader.pad_token.call_one()\n",
- " while True:\n",
- " t = Tracer(\"main_perf/continuous_rollouts\")\n",
- " t.start()\n",
- " sample = await dataloader.sample.call_one()\n",
- " if sample is None:\n",
- " print(\"Dataloader is empty, exiting continuous rollout\")\n",
- " return\n",
- "\n",
- " t.step(\"data_loading\")\n",
- "\n",
- " prompt, target = sample[\"request\"], sample[\"target\"]\n",
- " responses = await policy.generate.route(prompt)\n",
- " # TODO: this shall be part of the responses metadata instead of a separate call\n",
- " version = await policy.get_version.route()\n",
- "\n",
- " t.step(\"policy_generation\")\n",
- "\n",
- " assert (\n",
- " len(responses) > 0\n",
- " ), \"Sanity check: Responses should NEVER return empty\"\n",
- " assert (\n",
- " version := responses[0].generator_version\n",
- " ) is not None, \"Response must indicate a version\"\n",
- " group = Group.new_group(\n",
- " group_id=rollout_count,\n",
- " group_size=group_size,\n",
- " request=prompt,\n",
- " policy_version=version,\n",
- " pad_id=pad_id,\n",
- " request_len=max_req_tokens,\n",
- " response_len=max_res_tokens,\n",
- " target=target,\n",
- " )\n",
- "\n",
- " input_ids = torch.ones(\n",
- " (group_size, max_req_tokens + max_res_tokens),\n",
- " dtype=torch.long,\n",
- " device=\"cuda\",\n",
- " )\n",
- " # Populate episode info and calculate rewards\n",
- " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n",
- " episode.request_tokens = response.prompt_ids\n",
- " episode.response_tokens = response.token_ids\n",
- " episode.response = response.text\n",
- " input_ids[i, :max_req_tokens] = episode.request_tensor\n",
- " input_ids[i, max_req_tokens:] = episode.response_tensor\n",
- " episode.reward = await reward_actor.evaluate_response.route(\n",
- " prompt=prompt, response=response.text, target=target\n",
- " )\n",
- "\n",
- " t.step(\"reward_evaluation\")\n",
- "\n",
- " ref_logprobs = await ref_model.forward.route(\n",
- " input_ids, max_req_tokens, return_logprobs=True\n",
- " )\n",
- " t.step(\"reference_model_calculate_logprobs\")\n",
- "\n",
- " for i, episode in enumerate(group.episodes):\n",
- " episode.ref_logprobs = ref_logprobs[i]\n",
- " del ref_logprobs, input_ids\n",
- " t.step(\"compute_logprobs\")\n",
- "\n",
- " # Calculate advantages and add to replay buffer\n",
- " advantages = await compute_advantages.compute.call_one(group)\n",
- " for episode, advantage in zip(group.episodes, advantages):\n",
- " episode.advantage = advantage\n",
- " await replay_buffer.add.call_one(episode)\n",
- "\n",
- " # Log metrics\n",
- " rollout_count += 1\n",
- " record_metric(\n",
- " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n",
- " )\n",
- " t.stop()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b",
- "metadata": {},
- "source": [
- "## Training Loop"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9",
- "metadata": {},
- "outputs": [],
- "source": [
- "async def continuous_training():\n",
- " training_step = 0\n",
- " restart_tracer = True # Flag to control when to restart tracer\n",
- " while True:\n",
- " # Restart tracer when needed (initial start or after completing a training step)\n",
- " # Otherwise, we cannot measure time waiting for buffer\n",
- " if restart_tracer:\n",
- " t = Tracer(\"main_perf/continuous_training\")\n",
- " t.start()\n",
- " restart_tracer = False\n",
- "\n",
- " batch = await replay_buffer.sample.call_one(\n",
- " curr_policy_version=training_step\n",
- " )\n",
- " if batch is None:\n",
- " await asyncio.sleep(0.1)\n",
- " else:\n",
- " t.step(\"waiting_for_buffer\")\n",
- "\n",
- " inputs, targets = batch\n",
- " await trainer.train_step.call(inputs, targets)\n",
- " training_step += 1\n",
- " t.step(\"train_step\")\n",
- "\n",
- " await trainer.push_weights.call(training_step)\n",
- " t.step(\"push_weights\")\n",
- "\n",
- " await policy.update_weights.fanout(training_step)\n",
- " update_task = asyncio.create_task(policy.update_weights.fanout(training_step))\n",
- " t.step(\"update_weights\")\n",
- "\n",
- " if training_step >= 2:\n",
- " await drop_weights(training_step - 1)\n",
- " t.step(\"drop_weights\")\n",
- "\n",
- " t.stop()\n",
- " restart_tracer = True\n",
- "\n",
- " # Flush metrics every training step to WandB\n",
- " await mlogger.flush.call_one(training_step)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4542863b-59c5-40bc-896c-6d8d44ada00f",
- "metadata": {},
- "source": [
- "## Run"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "58194c13-b75e-405d-ab11-18cbe1874d92",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "num_rollout_threads = 1\n",
- "num_training_threads = 1\n",
- "\n",
- "rollout_tasks = [\n",
- " asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n",
- "]\n",
- "training_task = asyncio.create_task(continuous_training())\n",
- "\n",
- "try:\n",
- " await asyncio.gather(*rollout_tasks, training_task)\n",
- "except KeyboardInterrupt:\n",
- " print(\"Training interrupted by user\")\n",
- " for rollout_task in rollout_tasks:\n",
- " rollout_task.cancel()\n",
- " training_task.cancel()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b4603b80-1f25-49a1-920e-d24f38dfc687",
- "metadata": {},
- "source": [
- "## Shutdown"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81",
- "metadata": {},
- "outputs": [],
- "source": [
- "await mlogger.shutdown.call_one()\n",
- "await asyncio.sleep(2)\n",
- "\n",
- "await asyncio.gather(\n",
- " DatasetActor.shutdown(dataloader),\n",
- " policy.shutdown(),\n",
- " RLTrainer.shutdown(trainer),\n",
- " ReplayBuffer.shutdown(replay_buffer),\n",
- " ComputeAdvantages.shutdown(compute_advantages),\n",
- " ref_model.shutdown(),\n",
- " reward_actor.shutdown(),\n",
- ")\n",
- "await shutdown()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "forge",
- "language": "python",
- "name": "forge"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.18"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
+ "nbformat": 4,
+ "nbformat_minor": 2
}
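
For reference, the `simple_grpo_loss` cell above is a masked token-mean of a ratio-weighted advantage term minus a KL penalty. Transcribed into math (my reading of the code, not text from this PR), with r_t = ref_logprob_t - logprob_t, sg the stop-gradient, m_t the padding mask, and A the per-sequence advantage:

```latex
% k3-style KL estimator used in the cell: exp(r) - r - 1
\ell_t = -\Big( e^{\log\pi_t - \mathrm{sg}(\log\pi_t)}\, A \;-\; \beta\,\big(e^{r_t} - r_t - 1\big) \Big),
\qquad
\mathcal{L} = \operatorname{mean}_b \left[ \frac{\sum_t m_t\, \ell_t}{\max\big(\sum_t m_t,\; 1\big)} \right].
```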
diff --git a/assets/versions.sh b/assets/versions.sh
index 49a755dc0..4defaa2a6 100644
--- a/assets/versions.sh
+++ b/assets/versions.sh
@@ -14,6 +14,6 @@ PYTORCH_VERSION="2.9.0.dev20250905"
VLLM_BRANCH="v0.10.0"
# Commit hashes
-MONARCH_COMMIT="195503223b5c2896846171f60ac99dc6868f8f2c"
+MONARCH_COMMIT="main"
TORCHTITAN_COMMIT="d0e25450bcac2332359b13fbda430dc701f073d4"
TORCHSTORE_COMMIT="662299faf4fd50ee30bd9aa3f4ce8c0e2db1d310"
diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py
index 6c2efd5e6..a618e8805 100644
--- a/src/forge/actors/generator.py
+++ b/src/forge/actors/generator.py
@@ -154,7 +154,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
worker_procs = await get_proc_mesh(process_config=process_config)
# Then, grab a single host from the workers...
- host_mesh = await host_mesh_from_proc(worker_procs)
+ host_mesh = await host_mesh_from_proc(worker_procs._uid)
singleton_slice = {k: slice(0, 1) for k in host_mesh.extent.keys()}
host_mesh = host_mesh.slice(**singleton_slice)
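
The generator now resolves its host mesh by the uid stamped on the proc mesh rather than by the mesh object itself. Since `_uid` is only attached by forge's `get_proc_mesh` helper (see provisioner.py below), a defensive variant would guard for its absence; a hedged sketch:

```python
# Sketch: defensive lookup, mirroring the provisioner's own None-check.
# Assumes `_uid` is attached by get_proc_mesh, as introduced in this PR.
uid = getattr(worker_procs, "_uid", None)
if uid is None:
    raise ValueError("proc mesh was not allocated through get_proc_mesh")
host_mesh = await host_mesh_from_proc(uid)
```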
diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py
index a579200e9..992eb71a6 100644
--- a/src/forge/controller/__init__.py
+++ b/src/forge/controller/__init__.py
@@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.
from .actor import ForgeActor
from .provisioner import (
+ get_or_create_provisioner,
get_proc_mesh,
host_mesh_from_proc,
- init_provisioner,
shutdown,
stop_proc_mesh,
)
@@ -16,7 +16,7 @@
"ForgeActor",
"get_proc_mesh",
"stop_proc_mesh",
- "init_provisioner",
+ "get_or_create_provisioner",
"shutdown",
"host_mesh_from_proc",
]
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index cb36b2568..32e1ef2a3 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -16,7 +16,14 @@
from monarch._src.actor.actor_mesh import ActorMesh
from monarch._src.actor.shape import Extent
-from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
+from monarch.actor import (
+ Actor,
+ endpoint,
+ get_or_spawn_controller,
+ HostMesh,
+ ProcMesh,
+ this_host,
+)
from monarch.tools import commands
@@ -95,7 +102,7 @@ def release_gpus(self, gpu_ids: list[str]) -> None:
self.available_gpus.add(int(gpu_id))
-class Provisioner:
+class Provisioner(Actor):
"""A global resource provisioner."""
def __init__(self, cfg: ProvisionerConfig | None = None):
@@ -138,11 +145,13 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
self._registered_actors: list["ForgeActor"] = []
self._registered_services: list["ServiceInterface"] = []
+ @endpoint
async def initialize(self):
"""Call this after creating the instance"""
if self.launcher is not None:
await self.launcher.initialize()
+ @endpoint
async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
"""Creates a remote server and a HostMesh on it."""
# no need to lock here because this is already locked behind `get_proc_mesh`
@@ -172,6 +181,7 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
)
return host_mesh, server_name
+ @endpoint
def get_host_mesh(self, name: str) -> HostMesh:
"""Returns the host mesh given its associated name.
@@ -181,6 +191,7 @@ def get_host_mesh(self, name: str) -> HostMesh:
"""
return self._host_mesh_map[name]
+ @endpoint
async def get_proc_mesh(
self,
num_procs: int,
@@ -225,7 +236,7 @@ async def get_proc_mesh(
created_hosts = len(self._server_names)
mesh_name = f"alloc_{created_hosts}"
if host_mesh is None:
- host_mesh, server_name = await self.create_host_mesh(
+ host_mesh, server_name = await self.create_host_mesh.call_one(
name=mesh_name,
num_hosts=num_hosts,
)
@@ -283,6 +294,11 @@ def bootstrap(env: dict[str, str]):
per_host={"procs": num_procs},
bootstrap=functools.partial(bootstrap, env=env_vars),
)
+            # Generate a unique ID to map this proc mesh to its host mesh
+            uid = str(uuid.uuid4())
+            procs._uid = uid
+            print(f"Allocating procmesh with uid={uid}")
if with_gpus:
# Set up environment variables for PyTorch distributed...
@@ -308,7 +324,7 @@ def bootstrap(env: dict[str, str]):
self._server_names.append(server_name)
self._proc_server_map[procs] = server_name
- self._proc_host_map[procs] = host_mesh
+ self._proc_host_map[uid] = host_mesh
# Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
# When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
@@ -316,18 +332,27 @@ def bootstrap(env: dict[str, str]):
from forge.observability.metric_actors import get_or_create_metric_logger
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
- return procs
- async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
- if proc_mesh not in self._proc_host_map:
+ print(f"Returning procmesh with uid={uid}")
+ print(f"Returning procs._uid: {procs._uid}")
+ return procs, uid
+
+ @endpoint
+    async def host_mesh_from_proc(self, uid: str | None):
+        print(f"self._proc_host_map: {self._proc_host_map}")
+        print(f"looking up host mesh for uid: {uid}")
+ if uid is None or uid not in self._proc_host_map:
raise ValueError(
"The proc mesh was not allocated with an associated hostmesh."
)
- return self._proc_host_map[proc_mesh]
+ return self._proc_host_map[uid]
+ @endpoint
async def stop_proc_mesh(self, proc_mesh: ProcMesh):
"""Stops a proc mesh."""
- if proc_mesh not in self._proc_host_map:
+ uid: str | None = getattr(proc_mesh, "_uid", None)
+ if uid is None or uid not in self._proc_host_map:
logger.warning(
f"proc mesh {proc_mesh} was requested to be stopped, but was either already stopped or "
"was never registered with the provisioner."
@@ -350,8 +375,9 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
if proc_mesh in self._proc_server_map:
server_name = self._proc_server_map[proc_mesh]
commands.kill(server_name)
- del self._proc_host_map[proc_mesh]
+ del self._proc_host_map[uid]
+ @endpoint
def register_service(self, service: "ServiceInterface") -> None:
"""Registers a service allocation for cleanup."""
# Import ServiceInterface here instead of at top-level to avoid circular import
@@ -364,6 +390,7 @@ def register_service(self, service: "ServiceInterface") -> None:
self._registered_services.append(service)
+ @endpoint
def register_actor(self, actor: "ForgeActor") -> None:
"""Registers a single actor allocation for cleanup."""
@@ -372,13 +399,15 @@ def register_actor(self, actor: "ForgeActor") -> None:
self._registered_actors.append(actor)
+ @endpoint
async def shutdown_all_allocations(self):
"""Gracefully shut down all tracked actors and services."""
+ global _global_registered_services
logger.info(
- f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..."
+ f"Shutting down {len(_global_registered_services)} service(s) and {len(self._registered_actors)} actor(s)..."
)
# --- ServiceInterface ---
- for service in reversed(self._registered_services):
+ for service in reversed(_global_registered_services):
try:
await service.shutdown()
@@ -398,29 +427,30 @@ async def shutdown_all_allocations(self):
self._registered_actors.clear()
self._registered_services.clear()
+ @endpoint
async def shutdown(self):
"""Tears down all remaining remote allocations."""
- await self.shutdown_all_allocations()
+ await self.shutdown_all_allocations.call_one()
async with self._lock:
for server_name in self._server_names:
commands.kill(server_name)
-_provisioner: Provisioner | None = None
-
+_global_provisioner: Provisioner | None = None
+_global_registered_services: list["ServiceInterface"] = []
-async def init_provisioner(cfg: ProvisionerConfig | None = None):
- global _provisioner
- if not _provisioner:
- _provisioner = Provisioner(cfg)
- await _provisioner.initialize()
- return _provisioner
-
-async def _get_provisioner():
- if not _provisioner:
- await init_provisioner()
- return _provisioner
+async def get_or_create_provisioner(
+ cfg: ProvisionerConfig | None = None,
+) -> Provisioner:
+ """Gets or spawns the global Provisioner controller actor."""
+ global _global_provisioner
+ if _global_provisioner is None:
+ _global_provisioner = await get_or_spawn_controller(
+ "provisioner_controller", Provisioner, cfg
+ )
+ await _global_provisioner.initialize.call_one()
+ return _global_provisioner
async def get_proc_mesh(
@@ -445,8 +475,8 @@ async def get_proc_mesh(
A proc mesh.
"""
- provisioner = await _get_provisioner()
- return await provisioner.get_proc_mesh(
+ provisioner = await get_or_create_provisioner()
+ procs, uid = await provisioner.get_proc_mesh.call_one(
num_procs=process_config.procs,
with_gpus=process_config.with_gpus,
num_hosts=process_config.hosts,
@@ -456,34 +486,39 @@ async def get_proc_mesh(
port=port,
addr=addr,
)
+    # Re-attach the uid locally: attributes set on the mesh inside the actor
+    # do not survive serialization back to the caller.
+    setattr(procs, "_uid", uid)
+    print(f"Setting procs._uid: {procs._uid}")
+ return procs
-async def host_mesh_from_proc(proc_mesh: ProcMesh):
+async def host_mesh_from_proc(uid: str | None):
"""Returns the host mesh that allocated the original proc_mesh.
This functionality will be enabled in Monarch, so this is a temporary
API.
"""
- provisioner = await _get_provisioner()
- return await provisioner.host_mesh_from_proc(proc_mesh)
+ provisioner = await get_or_create_provisioner()
+ return await provisioner.host_mesh_from_proc.call_one(uid)
async def register_service(service: "ServiceInterface") -> None:
"""Registers a service allocation with the global provisioner."""
- provisioner = await _get_provisioner()
- provisioner.register_service(service)
+
+ # TODO: This is a temporary hack. Change this back once Services are actors
+ global _global_registered_services
+ _global_registered_services.append(service)
async def register_actor(actor: "ForgeActor") -> None:
"""Registers an actor allocation with the global provisioner."""
- provisioner = await _get_provisioner()
- provisioner.register_actor(actor)
+ provisioner = await get_or_create_provisioner()
+ await provisioner.register_actor.call_one(actor)
async def stop_proc_mesh(proc_mesh: ProcMesh):
- provisioner = await _get_provisioner()
- return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh)
+ provisioner = await get_or_create_provisioner()
+ return await provisioner.stop_proc_mesh.call_one(proc_mesh=proc_mesh)
async def shutdown_metric_logger():
@@ -504,8 +539,8 @@ async def shutdown():
logger.info("Shutting down provisioner..")
- provisioner = await _get_provisioner()
- result = await provisioner.shutdown()
+ provisioner = await get_or_create_provisioner()
+ result = await provisioner.shutdown.call_one()
logger.info("Shutdown completed successfully")
return result
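
The thrust of the provisioner changes: once `Provisioner` runs as an actor, the `ProcMesh` a caller holds has round-tripped through serialization, so it is no longer a usable dict key for `_proc_host_map`; a string uid survives the trip. A toy illustration of why object-keyed lookup breaks (plain Python stand-ins, not Monarch types):

```python
import uuid

class Mesh:
    """Stand-in for ProcMesh: default identity-based hashing."""

registry: dict[object, str] = {}

def allocate() -> tuple[Mesh, str]:
    m, uid = Mesh(), str(uuid.uuid4())
    registry[m] = "host-A"    # object-keyed, like the old _proc_host_map
    registry[uid] = "host-A"  # uid-keyed, like the new _proc_host_map
    return m, uid

def actor_roundtrip(m: Mesh) -> Mesh:
    # Simulates crossing the actor boundary: the caller gets a new object.
    return Mesh()

m, uid = allocate()
remote = actor_roundtrip(m)
assert remote not in registry     # identity lost, object lookup fails
assert registry[uid] == "host-A"  # the string uid still resolves
```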
diff --git a/test.py b/test.py
new file mode 100644
index 000000000..680c2007a
--- /dev/null
+++ b/test.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Minimal repro: Provisioner host_mesh_from_proc() UID mapping bug
+#
+# Run this with:
+#   python test.py
+
+import asyncio
+
+# import pytest
+
+from forge.controller.provisioner import (
+ get_or_create_provisioner,
+ get_proc_mesh,
+ stop_proc_mesh,
+)
+from forge.types import ProcessConfig
+
+
+# @pytest.mark.asyncio
+async def test_provisioner_host_mesh_lookup_uid_mapping():
+ prov = await get_or_create_provisioner()
+ pm = await get_proc_mesh(
+ ProcessConfig(procs=1, with_gpus=False, hosts=None, mesh_name="uid_repro")
+ )
+ # UID is attached locally by the helper
+ assert hasattr(pm, "_uid") and pm._uid, "missing _uid on returned ProcMesh"
+ print(f"✅ got ProcMesh with UID {pm._uid}")
+    hm = await prov.host_mesh_from_proc.call_one(pm._uid)  # passing pm directly loses _uid across the actor boundary
+ assert hm is not None
+ await stop_proc_mesh(pm)
+ print("✅ repro passed")
+
+
+if __name__ == "__main__":
+ asyncio.run(test_provisioner_host_mesh_lookup_uid_mapping())
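
If this repro graduates from a root-level script into the test suite, the commented-out pytest hooks can be re-enabled; a sketch assuming pytest-asyncio is installed and the same imports as above:

```python
import pytest

@pytest.mark.asyncio
async def test_provisioner_host_mesh_lookup_uid_mapping():
    prov = await get_or_create_provisioner()
    pm = await get_proc_mesh(
        ProcessConfig(procs=1, with_gpus=False, hosts=None, mesh_name="uid_repro")
    )
    assert getattr(pm, "_uid", None), "missing _uid on returned ProcMesh"
    hm = await prov.host_mesh_from_proc.call_one(pm._uid)
    assert hm is not None
    await stop_proc_mesh(pm)
```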
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 01f01a390..2d08cbe87 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -17,7 +17,7 @@
from forge.actors.generator import Generator
from forge.actors.trainer import RLTrainer
-from forge.controller.provisioner import init_provisioner
+from forge.controller.provisioner import get_or_create_provisioner
from forge.controller.service.service import uuid
from forge.types import LauncherConfig, ProvisionerConfig
@@ -194,7 +194,7 @@ async def _setup_and_teardown(request):
logger.info(f"`trainer.use_dcp` is overriden to {use_dcp_override}")
if cfg.get("provisioner", None) is not None:
- await init_provisioner(
+ await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
await ts.initialize(strategy=ts.ControllerStorageVolumes())
diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py
index 55714c49d..34dd13107 100644
--- a/tests/sandbox/rl_trainer/main.py
+++ b/tests/sandbox/rl_trainer/main.py
@@ -12,7 +12,7 @@
import torchstore as ts
from forge.actors.trainer import RLTrainer
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
-from forge.controller.provisioner import init_provisioner, shutdown
+from forge.controller.provisioner import get_or_create_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.perf_tracker import Tracer
from forge.types import (
@@ -164,7 +164,7 @@ async def main(cfg: DictConfig):
trainer_dp_degree = cfg.trainer.parallelism.get("data_parallel_shard_degree", 1)
dp_size = trainer_dp_degree if trainer_dp_degree != -1 else 1
- await init_provisioner(
+ await get_or_create_provisioner(
ProvisionerConfig(
launcher_config=LauncherConfig(
launcher=cfg.get(LAUNCHER_KEY, Launcher.SLURM.value),
diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py
index 425352340..2c1a0d5e4 100644
--- a/tests/sandbox/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -15,7 +15,7 @@
from forge.actors.generator import Generator
-from forge.controller.provisioner import init_provisioner, shutdown
+from forge.controller.provisioner import get_or_create_provisioner, shutdown
from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
@@ -29,7 +29,7 @@
async def run(cfg: DictConfig):
if cfg.get("provisioner", None) is not None:
- await init_provisioner(
+ await get_or_create_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
)
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})