diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1a9bad2ca..0069d0b17 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,3 +47,9 @@ repos:
hooks:
- id: pydoclint
args: [--config=pyproject.toml]
+
+- repo: https://github.com/fastai/nbdev.git
+ rev: 2.4.5
+ hooks:
+ - id: nbdev_clean
+ args: [--clear_all]
diff --git a/apps/grpo/notebook.ipynb b/apps/grpo/notebook.ipynb
new file mode 100644
index 000000000..8d9fbc75a
--- /dev/null
+++ b/apps/grpo/notebook.ipynb
@@ -0,0 +1,669 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "46c66f45-4be3-4674-a870-3849c1048ddb",
+ "metadata": {},
+ "source": [
+ "# GRPO for Math (GSM8k)\n",
+ "\n",
+ "## Import modules"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
+ "# All rights reserved.\n",
+ "#\n",
+ "# This source code is licensed under the BSD-style license found in the\n",
+ "# LICENSE file in the root directory of this source tree.\n",
+ "\n",
+ "import asyncio\n",
+ "import time\n",
+ "import uuid\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Any, Callable\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import torchstore as ts\n",
+ "from datasets import load_dataset\n",
+ "from forge.actors._torchstore_utils import (\n",
+ " get_dcp_whole_state_dict_key,\n",
+ " get_param_prefix,\n",
+ ")\n",
+ "from forge.actors.generator import Generator as Policy\n",
+ "from forge.actors.reference_model import ReferenceModel\n",
+ "from forge.actors.replay_buffer import ReplayBuffer\n",
+ "from forge.actors.trainer import RLTrainer\n",
+ "from forge.cli.config import parse\n",
+ "from forge.controller.actor import ForgeActor\n",
+ "from forge.controller.provisioner import init_provisioner, shutdown\n",
+ "from forge.data.rewards import MathReward, ThinkingReward\n",
+ "from forge.observability.metric_actors import get_or_create_metric_logger\n",
+ "from forge.observability.metrics import record_metric, Reduce\n",
+ "from forge.observability.perf_tracker import Tracer\n",
+ "\n",
+ "from forge.types import LauncherConfig, ProvisionerConfig\n",
+ "from forge.util.ops import compute_logprobs\n",
+ "from monarch.actor import endpoint\n",
+ "from omegaconf import DictConfig\n",
+ "from vllm.transformers_utils.tokenizer import get_tokenizer\n",
+ "\n",
+ "import os\n",
+ "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n",
+ "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2",
+ "metadata": {},
+ "source": [
+ "## Define Data Structures"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class Episode:\n",
+ " # TODO: add adtional layer for multi-turn\n",
+ " episode_id: str\n",
+ " request: str\n",
+ " policy_version: int\n",
+ " pad_id: int\n",
+ " request_len: int\n",
+ " response_len: int\n",
+ " target: Any | None = None\n",
+ " # processed data\n",
+ " response: str | None = None\n",
+ " request_tokens: list[int] | None = None\n",
+ " response_tokens: list[int] | None = None\n",
+ " ref_logprobs: torch.Tensor | None = None\n",
+ " reward: float | None = None\n",
+ " advantage: float | None = None\n",
+ "\n",
+ " @property\n",
+ " def request_tensor(self):\n",
+ " tensor = torch.tensor(self.request_tokens, dtype=torch.long)\n",
+ " if tensor.shape[0] < self.request_len: # left pad\n",
+ " diff = self.request_len - tensor.shape[0]\n",
+ " tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n",
+ " return tensor\n",
+ "\n",
+ " @property\n",
+ " def response_tensor(self):\n",
+ " tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n",
+ " if tensor.shape[0] < self.response_len: # right pad\n",
+ " diff = self.response_len - tensor.shape[0]\n",
+ " tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n",
+ " return tensor\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class Group:\n",
+ " group_id: str\n",
+ " episodes: list[Episode]\n",
+ "\n",
+ " @classmethod\n",
+ " def new_group(\n",
+ " cls,\n",
+ " group_id: int,\n",
+ " group_size: int,\n",
+ " request: str,\n",
+ " policy_version: int,\n",
+ " pad_id: int,\n",
+ " request_len: int,\n",
+ " response_len: int,\n",
+ " target: Any = None,\n",
+ " ):\n",
+ " episodes = []\n",
+ " for _ in range(group_size):\n",
+ " episodes.append(\n",
+ " Episode(\n",
+ " episode_id=str(uuid.uuid4()),\n",
+ " request=request,\n",
+ " policy_version=policy_version,\n",
+ " pad_id=pad_id,\n",
+ " request_len=request_len,\n",
+ " response_len=response_len,\n",
+ " target=target,\n",
+ " )\n",
+ " )\n",
+ " return cls(str(group_id), episodes)\n",
+ "\n",
+ "\n",
+ "def collate(batches: list[list[Episode]]):\n",
+ " inputs = []\n",
+ " targets = []\n",
+ " for batch in batches:\n",
+ " request = [e.request_tensor for e in batch]\n",
+ " request = torch.stack(request) # [b x s]\n",
+ "\n",
+ " response = [e.response_tensor for e in batch]\n",
+ " response = torch.stack(response) # [b x s]\n",
+ "\n",
+ " ref_logprobs = [e.ref_logprobs for e in batch]\n",
+ " ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]\n",
+ "\n",
+ " advantages = [e.advantage for e in batch]\n",
+ " advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]\n",
+ "\n",
+ " pad_id = batch[0].pad_id\n",
+ " mask = response != pad_id\n",
+ "\n",
+ " input = {\"tokens\": torch.cat([request, response], dim=1)}\n",
+ " target = {\n",
+ " \"response\": response,\n",
+ " \"ref_logprobs\": ref_logprobs,\n",
+ " \"advantages\": advantages,\n",
+ " \"padding_mask\": mask,\n",
+ " }\n",
+ " inputs.append(input)\n",
+ " targets.append(target)\n",
+ " return inputs, targets\n",
+ "\n",
+ "@dataclass\n",
+ "class DatasetActor(ForgeActor):\n",
+ " \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n",
+ "\n",
+ " path: str = \"openai/gsm8k\"\n",
+ " revision: str = \"main\"\n",
+ " data_split: str = \"train\"\n",
+ " streaming: bool = True\n",
+ " model: str = \"Qwen/Qwen3-1.7B\"\n",
+ "\n",
+ " @endpoint\n",
+ " def setup(self):\n",
+ " self._tokenizer = get_tokenizer(self.model)\n",
+ "\n",
+ " def gsm8k_transform(sample):\n",
+ " system_prompt = \"\"\"\n",
+ " Put all your scratchpad work between and tags.\n",
+ " Your final answer should be between and tags otherwise it will not be scored.\n",
+ " \"\"\"\n",
+ " request: str = sample[\"question\"]\n",
+ " as_chat = [\n",
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
+ " {\"role\": \"user\", \"content\": request},\n",
+ " ]\n",
+ " formatted_request = self._tokenizer.apply_chat_template(\n",
+ " as_chat,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=True,\n",
+ " )\n",
+ " target: str = sample[\"answer\"]\n",
+ " formatted_target = target.split(\"#### \")[1]\n",
+ " return {\"request\": formatted_request, \"target\": formatted_target}\n",
+ "\n",
+ " ds = load_dataset(\n",
+ " self.path, self.revision, split=self.data_split, streaming=self.streaming\n",
+ " )\n",
+ " ds = ds.map(gsm8k_transform)\n",
+ " ds = ds.shuffle()\n",
+ " self._iterator = iter(ds)\n",
+ "\n",
+ " @endpoint\n",
+ " async def sample(self) -> dict[str, str] | None:\n",
+ " try:\n",
+ " sample = next(self._iterator)\n",
+ "\n",
+ " # Record dataset metrics\n",
+ " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n",
+ " record_metric(\n",
+ " \"dataset/sample/avg_sample_len\",\n",
+ " len(sample[\"request\"]),\n",
+ " Reduce.MEAN,\n",
+ " )\n",
+ "\n",
+ " return sample\n",
+ " except StopIteration:\n",
+ " return None\n",
+ "\n",
+ " @endpoint\n",
+ " async def pad_token(self):\n",
+ " return self._tokenizer.pad_token_id"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef",
+ "metadata": {},
+ "source": [
+ "## Define loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "934aca32-0953-4945-9f99-e7b34804443b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def simple_grpo_loss(\n",
+ " logits: torch.Tensor,\n",
+ " response: torch.Tensor,\n",
+ " ref_logprobs: torch.Tensor,\n",
+ " advantages: torch.Tensor,\n",
+ " padding_mask: torch.Tensor,\n",
+ " beta: float = 0.1,\n",
+ ") -> torch.Tensor:\n",
+ " \"\"\"\n",
+ " Example GRPO Loss Function for RLTrainer\n",
+ " \"\"\"\n",
+ " logprobs: torch.Tensor = compute_logprobs(logits, response)\n",
+ "\n",
+ " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n",
+ " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n",
+ " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n",
+ " per_token_loss = -(per_token_policy_loss - beta * kl)\n",
+ " loss = (\n",
+ " ((per_token_loss * padding_mask).sum(dim=1))\n",
+ " / (padding_mask.sum(dim=1).clamp(min=1.0))\n",
+ " ).mean()\n",
+ " return loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104",
+ "metadata": {},
+ "source": [
+ "## Define Reward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class RewardActor(ForgeActor):\n",
+ " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n",
+ "\n",
+ " reward_functions: list[Callable]\n",
+ "\n",
+ " @endpoint\n",
+ " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n",
+ " total_rewards = 0.0\n",
+ " for reward_fn in self.reward_functions:\n",
+ " reward = reward_fn(prompt, response, target)\n",
+ " total_rewards += reward\n",
+ "\n",
+ " # Get a name for the reward function (works for classes, functions, lambdas)\n",
+ " reward_fn_name = getattr(\n",
+ " reward_fn, \"__name__\", reward_fn.__class__.__name__\n",
+ " )\n",
+ " # per function reward\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n",
+ " reward,\n",
+ " Reduce.SUM,\n",
+ " )\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n",
+ " reward,\n",
+ " Reduce.MEAN,\n",
+ " )\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n",
+ " reward,\n",
+ " Reduce.STD,\n",
+ " )\n",
+ "\n",
+ " # avg total reward\n",
+ " record_metric(\n",
+ " \"reward/evaluate_response/avg_total_reward\",\n",
+ " reward,\n",
+ " Reduce.MEAN,\n",
+ " )\n",
+ "\n",
+ " # count fn calls\n",
+ " record_metric(\n",
+ " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n",
+ " 1,\n",
+ " Reduce.SUM,\n",
+ " )\n",
+ "\n",
+ " avg_reward = total_rewards / len(self.reward_functions)\n",
+ " return avg_reward\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class ComputeAdvantages(ForgeActor):\n",
+ " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n",
+ "\n",
+ " @endpoint\n",
+ " async def compute(self, group: Group) -> list[float]:\n",
+ " # TODO: add batch processing\n",
+ " rewards = torch.tensor([[e.reward for e in group.episodes]])\n",
+ " mean = rewards.mean(1, keepdim=True)\n",
+ " std = rewards.std(1, keepdim=True)\n",
+ " advantages = (rewards - mean) / (std + 1e-4)\n",
+ " return advantages.squeeze(0).tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "88523484-b414-41db-bd3f-0d8dbf881a85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def drop_weights(version: int):\n",
+ " print(f\"Dropping weights @ version {version}\")\n",
+ " start_time = time.perf_counter()\n",
+ " prefix = get_param_prefix(version)\n",
+ " matching_keys = await ts.keys(prefix)\n",
+ " # TODO: once we have something like `get_meta()` in torchstore, we can just\n",
+ " # query the type of the object instead of relying on keys.\n",
+ " dcp_key = get_dcp_whole_state_dict_key(version)\n",
+ " if dcp_key in matching_keys:\n",
+ " dcp_handle = await ts.get(dcp_key)\n",
+ " dcp_handle.drop()\n",
+ " for key in matching_keys:\n",
+ " await ts.delete(key)\n",
+ " elapsed = time.perf_counter() - start_time\n",
+ " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72",
+ "metadata": {},
+ "source": [
+ "## Setup Services"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6c811974-cd6b-40ed-a179-4511a7a6c489",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "from omegaconf import OmegaConf\n",
+ "from forge.cli.config import resolve_hf_hub_paths\n",
+ "\n",
+ "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n",
+ "cfg = resolve_hf_hub_paths(cfg)\n",
+ "OmegaConf.resolve(cfg)\n",
+ "\n",
+ "group_size = cfg.group_size # 8\n",
+ "max_req_tokens = cfg.max_req_tokens # 512\n",
+ "max_res_tokens = cfg.max_res_tokens # 512\n",
+ "\n",
+ "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n",
+ "mlogger = await get_or_create_metric_logger()\n",
+ "await mlogger.init_backends.call_one(metric_logging_cfg)\n",
+ "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n",
+ "\n",
+ "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n",
+ " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n",
+ " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n",
+ " RLTrainer.options(**cfg.actors.trainer).as_actor(\n",
+ " **cfg.trainer, loss=simple_grpo_loss\n",
+ " ),\n",
+ " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n",
+ " **cfg.replay_buffer, collate=collate\n",
+ " ),\n",
+ " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n",
+ " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n",
+ " RewardActor.options(**cfg.services.reward_actor).as_service(\n",
+ " reward_functions=[MathReward(), ThinkingReward()]\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7",
+ "metadata": {},
+ "source": [
+ "## Rollout Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def continuous_rollouts():\n",
+ " rollout_count = 0\n",
+ " pad_id = await dataloader.pad_token.call_one()\n",
+ " while True:\n",
+ " t = Tracer(\"main_perf/continuous_rollouts\")\n",
+ " t.start()\n",
+ " sample = await dataloader.sample.call_one()\n",
+ " if sample is None:\n",
+ " print(\"Dataloader is empty, exiting continuous rollout\")\n",
+ " return\n",
+ "\n",
+ " t.step(\"data_loading\")\n",
+ "\n",
+ " prompt, target = sample[\"request\"], sample[\"target\"]\n",
+ " responses = await policy.generate.route(prompt)\n",
+ " # TODO: this shall be part of the responses metadata instead of a separate call\n",
+ " version = await policy.get_version.route()\n",
+ "\n",
+ " t.step(\"policy_generation\")\n",
+ "\n",
+ " assert (\n",
+ " len(responses) > 0\n",
+ " ), \"Sanity check: Responses should NEVER return empty\"\n",
+ " assert (\n",
+ " version := responses[0].generator_version\n",
+ " ) is not None, \"Response must indicate a version\"\n",
+ " group = Group.new_group(\n",
+ " group_id=rollout_count,\n",
+ " group_size=group_size,\n",
+ " request=prompt,\n",
+ " policy_version=version,\n",
+ " pad_id=pad_id,\n",
+ " request_len=max_req_tokens,\n",
+ " response_len=max_res_tokens,\n",
+ " target=target,\n",
+ " )\n",
+ "\n",
+ " input_ids = torch.ones(\n",
+ " (group_size, max_req_tokens + max_res_tokens),\n",
+ " dtype=torch.long,\n",
+ " device=\"cuda\",\n",
+ " )\n",
+ " # Populate episode info and calculate rewards\n",
+ " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n",
+ " episode.request_tokens = response.prompt_ids\n",
+ " episode.response_tokens = response.token_ids\n",
+ " episode.response = response.text\n",
+ " input_ids[i, :max_req_tokens] = episode.request_tensor\n",
+ " input_ids[i, max_req_tokens:] = episode.response_tensor\n",
+ " episode.reward = await reward_actor.evaluate_response.route(\n",
+ " prompt=prompt, response=response.text, target=target\n",
+ " )\n",
+ "\n",
+ " t.step(\"reward_evaluation\")\n",
+ "\n",
+ " ref_logprobs = await ref_model.forward.route(\n",
+ " input_ids, max_req_tokens, return_logprobs=True\n",
+ " )\n",
+ " t.step(\"reference_model_calculate_logprobs\")\n",
+ "\n",
+ " for i, episode in enumerate(group.episodes):\n",
+ " episode.ref_logprobs = ref_logprobs[i]\n",
+ " del ref_logprobs, input_ids\n",
+ " t.step(\"compute_logprobs\")\n",
+ "\n",
+ " # Calculate advantages and add to replay buffer\n",
+ " advantages = await compute_advantages.compute.call_one(group)\n",
+ " for episode, advantage in zip(group.episodes, advantages):\n",
+ " episode.advantage = advantage\n",
+ " await replay_buffer.add.call_one(episode)\n",
+ "\n",
+ " # Log metrics\n",
+ " rollout_count += 1\n",
+ " record_metric(\n",
+ " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n",
+ " )\n",
+ " t.stop()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b",
+ "metadata": {},
+ "source": [
+ "## Training Loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "async def continuous_training():\n",
+ " training_step = 0\n",
+ " restart_tracer = True # Flag to control when to restart tracer\n",
+ " while True:\n",
+ " # Restart tracer when needed (initial start or after completing a training step)\n",
+ " # Otherwise, we cannot measure time waiting for buffer\n",
+ " if restart_tracer:\n",
+ " t = Tracer(\"main_perf/continuous_training\")\n",
+ " t.start()\n",
+ " restart_tracer = False\n",
+ "\n",
+ " batch = await replay_buffer.sample.call_one(\n",
+ " curr_policy_version=training_step\n",
+ " )\n",
+ " if batch is None:\n",
+ " await asyncio.sleep(0.1)\n",
+ " else:\n",
+ " t.step(\"waiting_for_buffer\")\n",
+ "\n",
+ " inputs, targets = batch\n",
+ " await trainer.train_step.call(inputs, targets)\n",
+ " training_step += 1\n",
+ " t.step(\"train_step\")\n",
+ "\n",
+ " await trainer.push_weights.call(training_step)\n",
+ " t.step(\"push_weights\")\n",
+ "\n",
+ " await policy.update_weights.fanout(training_step)\n",
+ " update_task = asyncio.create_task(policy.update_weights.fanout(training_step))\n",
+ " t.step(\"update_weights\")\n",
+ "\n",
+ " if training_step >= 2:\n",
+ " await drop_weights(training_step - 1)\n",
+ " t.step(\"drop_weights\")\n",
+ "\n",
+ " t.stop()\n",
+ " restart_tracer = True\n",
+ "\n",
+ " # Flush metrics every training step to WandB\n",
+ " await mlogger.flush.call_one(training_step)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4542863b-59c5-40bc-896c-6d8d44ada00f",
+ "metadata": {},
+ "source": [
+ "## Run"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "58194c13-b75e-405d-ab11-18cbe1874d92",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "num_rollout_threads = 1\n",
+ "num_training_threads = 1\n",
+ "\n",
+ "rollout_tasks = [\n",
+ " asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n",
+ "]\n",
+ "training_task = asyncio.create_task(continuous_training())\n",
+ "\n",
+ "try:\n",
+ " await asyncio.gather(*rollout_tasks, training_task)\n",
+ "except KeyboardInterrupt:\n",
+ " print(\"Training interrupted by user\")\n",
+ " for rollout_task in rollout_tasks:\n",
+ " rollout_task.cancel()\n",
+ " training_task.cancel()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4603b80-1f25-49a1-920e-d24f38dfc687",
+ "metadata": {},
+ "source": [
+ "## Shutdown"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "await mlogger.shutdown.call_one()\n",
+ "await asyncio.sleep(2)\n",
+ "\n",
+ "await asyncio.gather(\n",
+ " DatasetActor.shutdown(dataloader),\n",
+ " policy.shutdown(),\n",
+ " RLTrainer.shutdown(trainer),\n",
+ " ReplayBuffer.shutdown(replay_buffer),\n",
+ " ComputeAdvantages.shutdown(compute_advantages),\n",
+ " ref_model.shutdown(),\n",
+ " reward_actor.shutdown(),\n",
+ ")\n",
+ "await shutdown()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "forge",
+ "language": "python",
+ "name": "forge"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}