diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a9bad2ca..0069d0b17 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,3 +47,9 @@ repos: hooks: - id: pydoclint args: [--config=pyproject.toml] + +- repo: https://github.com/fastai/nbdev.git + rev: 2.4.5 + hooks: + - id: nbdev_clean + args: [--clear_all] diff --git a/apps/grpo/notebook.ipynb b/apps/grpo/notebook.ipynb new file mode 100644 index 000000000..8d9fbc75a --- /dev/null +++ b/apps/grpo/notebook.ipynb @@ -0,0 +1,669 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "46c66f45-4be3-4674-a870-3849c1048ddb", + "metadata": {}, + "source": [ + "# GRPO for Math (GSM8k)\n", + "\n", + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97d9ca00-92a8-4bd3-9b2b-ab8856f5acce", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "# All rights reserved.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree.\n", + "\n", + "import asyncio\n", + "import time\n", + "import uuid\n", + "from dataclasses import dataclass\n", + "from typing import Any, Callable\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import torchstore as ts\n", + "from datasets import load_dataset\n", + "from forge.actors._torchstore_utils import (\n", + " get_dcp_whole_state_dict_key,\n", + " get_param_prefix,\n", + ")\n", + "from forge.actors.generator import Generator as Policy\n", + "from forge.actors.reference_model import ReferenceModel\n", + "from forge.actors.replay_buffer import ReplayBuffer\n", + "from forge.actors.trainer import RLTrainer\n", + "from forge.cli.config import parse\n", + "from forge.controller.actor import ForgeActor\n", + "from forge.controller.provisioner import init_provisioner, shutdown\n", + "from forge.data.rewards import MathReward, ThinkingReward\n", + "from forge.observability.metric_actors import get_or_create_metric_logger\n", + "from forge.observability.metrics import record_metric, Reduce\n", + "from forge.observability.perf_tracker import Tracer\n", + "\n", + "from forge.types import LauncherConfig, ProvisionerConfig\n", + "from forge.util.ops import compute_logprobs\n", + "from monarch.actor import endpoint\n", + "from omegaconf import DictConfig\n", + "from vllm.transformers_utils.tokenizer import get_tokenizer\n", + "\n", + "import os\n", + "os.environ[\"MONARCH_HOSTMESH_V1\"] = \"1\"\n", + "os.environ[\"TORCHSTORE_RDMA_ENABLED\"] = \"1\"" + ] + }, + { + "cell_type": "markdown", + "id": "34d4319f-e6c9-4f4b-9b92-c572de08f0b2", + "metadata": {}, + "source": [ + "## Define Data Structures" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4a25e9d-e1dd-4ea7-a80c-383a2c04656a", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class Episode:\n", + " # TODO: add adtional layer for multi-turn\n", + " episode_id: str\n", + " request: str\n", + " policy_version: int\n", + " pad_id: int\n", + " request_len: int\n", + " response_len: int\n", + " target: Any | None = None\n", + " # processed data\n", + " response: str | None = None\n", + " request_tokens: list[int] | None = None\n", + " response_tokens: list[int] | None = None\n", + " ref_logprobs: torch.Tensor | None = None\n", + " reward: float | None = None\n", + " advantage: float | None = None\n", + "\n", + " @property\n", + " def request_tensor(self):\n", + " tensor = 
torch.tensor(self.request_tokens, dtype=torch.long)\n", +    "        if tensor.shape[0] < self.request_len:  # left pad\n", +    "            diff = self.request_len - tensor.shape[0]\n", +    "            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)\n", +    "        return tensor\n", +    "\n", +    "    @property\n", +    "    def response_tensor(self):\n", +    "        tensor = torch.tensor(self.response_tokens, dtype=torch.long)\n", +    "        if tensor.shape[0] < self.response_len:  # right pad\n", +    "            diff = self.response_len - tensor.shape[0]\n", +    "            tensor = F.pad(tensor, (0, diff), value=self.pad_id)\n", +    "        return tensor\n", +    "\n", +    "\n", +    "@dataclass\n", +    "class Group:\n", +    "    group_id: str\n", +    "    episodes: list[Episode]\n", +    "\n", +    "    @classmethod\n", +    "    def new_group(\n", +    "        cls,\n", +    "        group_id: int,\n", +    "        group_size: int,\n", +    "        request: str,\n", +    "        policy_version: int,\n", +    "        pad_id: int,\n", +    "        request_len: int,\n", +    "        response_len: int,\n", +    "        target: Any = None,\n", +    "    ):\n", +    "        episodes = []\n", +    "        for _ in range(group_size):\n", +    "            episodes.append(\n", +    "                Episode(\n", +    "                    episode_id=str(uuid.uuid4()),\n", +    "                    request=request,\n", +    "                    policy_version=policy_version,\n", +    "                    pad_id=pad_id,\n", +    "                    request_len=request_len,\n", +    "                    response_len=response_len,\n", +    "                    target=target,\n", +    "                )\n", +    "            )\n", +    "        return cls(str(group_id), episodes)\n", +    "\n", +    "\n", +    "def collate(batches: list[list[Episode]]):\n", +    "    inputs = []\n", +    "    targets = []\n", +    "    for batch in batches:\n", +    "        request = [e.request_tensor for e in batch]\n", +    "        request = torch.stack(request)  # [b x s]\n", +    "\n", +    "        response = [e.response_tensor for e in batch]\n", +    "        response = torch.stack(response)  # [b x s]\n", +    "\n", +    "        ref_logprobs = [e.ref_logprobs for e in batch]\n", +    "        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]\n", +    "\n", +    "        advantages = [e.advantage for e in batch]\n", +    "        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]\n", +    "\n", +    "        pad_id = batch[0].pad_id\n", +    "        mask = response != pad_id\n", +    "\n", +    "        input = {\"tokens\": torch.cat([request, response], dim=1)}\n", +    "        target = {\n", +    "            \"response\": response,\n", +    "            \"ref_logprobs\": ref_logprobs,\n", +    "            \"advantages\": advantages,\n", +    "            \"padding_mask\": mask,\n", +    "        }\n", +    "        inputs.append(input)\n", +    "        targets.append(target)\n", +    "    return inputs, targets\n", +    "\n", +    "@dataclass\n", +    "class DatasetActor(ForgeActor):\n", +    "    \"\"\"Actor wrapper for HuggingFace dataset to provide async interface.\"\"\"\n", +    "\n", +    "    path: str = \"openai/gsm8k\"\n", +    "    revision: str = \"main\"\n", +    "    data_split: str = \"train\"\n", +    "    streaming: bool = True\n", +    "    model: str = \"Qwen/Qwen3-1.7B\"\n", +    "\n", +    "    @endpoint\n", +    "    def setup(self):\n", +    "        self._tokenizer = get_tokenizer(self.model)\n", +    "\n", +    "        def gsm8k_transform(sample):\n", +    "            system_prompt = \"\"\"\n", +    "            Put all your scratchpad work between <think> and </think> tags.\n", +    "            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.\n", +    "            \"\"\"\n", +    "            request: str = sample[\"question\"]\n", +    "            as_chat = [\n", +    "                {\"role\": \"system\", \"content\": system_prompt},\n", +    "                {\"role\": \"user\", \"content\": request},\n", +    "            ]\n", +    "            formatted_request = self._tokenizer.apply_chat_template(\n", +    "                as_chat,\n", +    "                tokenize=False,\n", +    "                add_generation_prompt=True,\n", +    "            )\n", +    "            target: str = sample[\"answer\"]\n", +    "            formatted_target = target.split(\"#### \")[1]\n", +    "            return {\"request\": formatted_request, \"target\": formatted_target}\n", +    "\n", +    "        
ds = load_dataset(\n", + " self.path, self.revision, split=self.data_split, streaming=self.streaming\n", + " )\n", + " ds = ds.map(gsm8k_transform)\n", + " ds = ds.shuffle()\n", + " self._iterator = iter(ds)\n", + "\n", + " @endpoint\n", + " async def sample(self) -> dict[str, str] | None:\n", + " try:\n", + " sample = next(self._iterator)\n", + "\n", + " # Record dataset metrics\n", + " record_metric(\"dataset/sample/count_samples_generated\", 1, Reduce.SUM)\n", + " record_metric(\n", + " \"dataset/sample/avg_sample_len\",\n", + " len(sample[\"request\"]),\n", + " Reduce.MEAN,\n", + " )\n", + "\n", + " return sample\n", + " except StopIteration:\n", + " return None\n", + "\n", + " @endpoint\n", + " async def pad_token(self):\n", + " return self._tokenizer.pad_token_id" + ] + }, + { + "cell_type": "markdown", + "id": "901b3d1d-7eba-4464-b881-48c11ff6e0ef", + "metadata": {}, + "source": [ + "## Define loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "934aca32-0953-4945-9f99-e7b34804443b", + "metadata": {}, + "outputs": [], + "source": [ + "def simple_grpo_loss(\n", + " logits: torch.Tensor,\n", + " response: torch.Tensor,\n", + " ref_logprobs: torch.Tensor,\n", + " advantages: torch.Tensor,\n", + " padding_mask: torch.Tensor,\n", + " beta: float = 0.1,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Example GRPO Loss Function for RLTrainer\n", + " \"\"\"\n", + " logprobs: torch.Tensor = compute_logprobs(logits, response)\n", + "\n", + " # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`\n", + " kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1\n", + " per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages\n", + " per_token_loss = -(per_token_policy_loss - beta * kl)\n", + " loss = (\n", + " ((per_token_loss * padding_mask).sum(dim=1))\n", + " / (padding_mask.sum(dim=1).clamp(min=1.0))\n", + " ).mean()\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "id": "d4f8bbe3-b7ac-4905-b197-f10990f9a104", + "metadata": {}, + "source": [ + "## Define Reward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "163e98bf-e0f5-4ec3-9690-9839e687f9b3", + "metadata": {}, + "outputs": [], + "source": [ + "@dataclass\n", + "class RewardActor(ForgeActor):\n", + " \"\"\"Reward actor that uses a list of scoring functions.\"\"\"\n", + "\n", + " reward_functions: list[Callable]\n", + "\n", + " @endpoint\n", + " async def evaluate_response(self, prompt: str, response: str, target: str) -> float:\n", + " total_rewards = 0.0\n", + " for reward_fn in self.reward_functions:\n", + " reward = reward_fn(prompt, response, target)\n", + " total_rewards += reward\n", + "\n", + " # Get a name for the reward function (works for classes, functions, lambdas)\n", + " reward_fn_name = getattr(\n", + " reward_fn, \"__name__\", reward_fn.__class__.__name__\n", + " )\n", + " # per function reward\n", + " record_metric(\n", + " f\"reward/evaluate_response/sum_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.SUM,\n", + " )\n", + " record_metric(\n", + " f\"reward/evaluate_response/avg_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.MEAN,\n", + " )\n", + " record_metric(\n", + " f\"reward/evaluate_response/std_{reward_fn_name}_reward\",\n", + " reward,\n", + " Reduce.STD,\n", + " )\n", + "\n", + " # avg total reward\n", + " record_metric(\n", + " \"reward/evaluate_response/avg_total_reward\",\n", + " reward,\n", + " Reduce.MEAN,\n", + " )\n", + "\n", + " # count fn calls\n", + " 
record_metric(\n", + " f\"reward/evaluate_response/count_{reward_fn_name}_calls\",\n", + " 1,\n", + " Reduce.SUM,\n", + " )\n", + "\n", + " avg_reward = total_rewards / len(self.reward_functions)\n", + " return avg_reward\n", + "\n", + "\n", + "@dataclass\n", + "class ComputeAdvantages(ForgeActor):\n", + " \"\"\"Compute advantages for GRPO using reward signals.\"\"\"\n", + "\n", + " @endpoint\n", + " async def compute(self, group: Group) -> list[float]:\n", + " # TODO: add batch processing\n", + " rewards = torch.tensor([[e.reward for e in group.episodes]])\n", + " mean = rewards.mean(1, keepdim=True)\n", + " std = rewards.std(1, keepdim=True)\n", + " advantages = (rewards - mean) / (std + 1e-4)\n", + " return advantages.squeeze(0).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88523484-b414-41db-bd3f-0d8dbf881a85", + "metadata": {}, + "outputs": [], + "source": [ + "async def drop_weights(version: int):\n", + " print(f\"Dropping weights @ version {version}\")\n", + " start_time = time.perf_counter()\n", + " prefix = get_param_prefix(version)\n", + " matching_keys = await ts.keys(prefix)\n", + " # TODO: once we have something like `get_meta()` in torchstore, we can just\n", + " # query the type of the object instead of relying on keys.\n", + " dcp_key = get_dcp_whole_state_dict_key(version)\n", + " if dcp_key in matching_keys:\n", + " dcp_handle = await ts.get(dcp_key)\n", + " dcp_handle.drop()\n", + " for key in matching_keys:\n", + " await ts.delete(key)\n", + " elapsed = time.perf_counter() - start_time\n", + " print(f\"Dropped weights @ version {version}, took {elapsed:.2f} seconds\")" + ] + }, + { + "cell_type": "markdown", + "id": "95d4fef3-180b-4b7e-8871-ecbe113cde72", + "metadata": {}, + "source": [ + "## Setup Services" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c811974-cd6b-40ed-a179-4511a7a6c489", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "from forge.cli.config import resolve_hf_hub_paths\n", + "\n", + "cfg = OmegaConf.load('apps/grpo/qwen3_1_7b.yaml')\n", + "cfg = resolve_hf_hub_paths(cfg)\n", + "OmegaConf.resolve(cfg)\n", + "\n", + "group_size = cfg.group_size # 8\n", + "max_req_tokens = cfg.max_req_tokens # 512\n", + "max_res_tokens = cfg.max_res_tokens # 512\n", + "\n", + "metric_logging_cfg = cfg.get(\"metric_logging\", {\"console\": {\"log_per_rank\": False}})\n", + "mlogger = await get_or_create_metric_logger()\n", + "await mlogger.init_backends.call_one(metric_logging_cfg)\n", + "await ts.initialize(strategy=ts.ControllerStorageVolumes())\n", + "\n", + "dataloader, policy, trainer, replay_buffer, compute_advantages, ref_model, reward_actor = await asyncio.gather(\n", + " DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),\n", + " Policy.options(**cfg.services.policy).as_service(**cfg.policy),\n", + " RLTrainer.options(**cfg.actors.trainer).as_actor(\n", + " **cfg.trainer, loss=simple_grpo_loss\n", + " ),\n", + " ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(\n", + " **cfg.replay_buffer, collate=collate\n", + " ),\n", + " ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),\n", + " ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),\n", + " RewardActor.options(**cfg.services.reward_actor).as_service(\n", + " reward_functions=[MathReward(), ThinkingReward()]\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3f2a305f-b1e2-4eac-803c-71bf3225fed7", + 
"metadata": {}, + "source": [ + "## Rollout Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1c676fb-2cd6-4c2c-87d4-e1b8cd0b87af", + "metadata": {}, + "outputs": [], + "source": [ + "async def continuous_rollouts():\n", + " rollout_count = 0\n", + " pad_id = await dataloader.pad_token.call_one()\n", + " while True:\n", + " t = Tracer(\"main_perf/continuous_rollouts\")\n", + " t.start()\n", + " sample = await dataloader.sample.call_one()\n", + " if sample is None:\n", + " print(\"Dataloader is empty, exiting continuous rollout\")\n", + " return\n", + "\n", + " t.step(\"data_loading\")\n", + "\n", + " prompt, target = sample[\"request\"], sample[\"target\"]\n", + " responses = await policy.generate.route(prompt)\n", + " # TODO: this shall be part of the responses metadata instead of a separate call\n", + " version = await policy.get_version.route()\n", + "\n", + " t.step(\"policy_generation\")\n", + "\n", + " assert (\n", + " len(responses) > 0\n", + " ), \"Sanity check: Responses should NEVER return empty\"\n", + " assert (\n", + " version := responses[0].generator_version\n", + " ) is not None, \"Response must indicate a version\"\n", + " group = Group.new_group(\n", + " group_id=rollout_count,\n", + " group_size=group_size,\n", + " request=prompt,\n", + " policy_version=version,\n", + " pad_id=pad_id,\n", + " request_len=max_req_tokens,\n", + " response_len=max_res_tokens,\n", + " target=target,\n", + " )\n", + "\n", + " input_ids = torch.ones(\n", + " (group_size, max_req_tokens + max_res_tokens),\n", + " dtype=torch.long,\n", + " device=\"cuda\",\n", + " )\n", + " # Populate episode info and calculate rewards\n", + " for i, (episode, response) in enumerate(zip(group.episodes, responses)):\n", + " episode.request_tokens = response.prompt_ids\n", + " episode.response_tokens = response.token_ids\n", + " episode.response = response.text\n", + " input_ids[i, :max_req_tokens] = episode.request_tensor\n", + " input_ids[i, max_req_tokens:] = episode.response_tensor\n", + " episode.reward = await reward_actor.evaluate_response.route(\n", + " prompt=prompt, response=response.text, target=target\n", + " )\n", + "\n", + " t.step(\"reward_evaluation\")\n", + "\n", + " ref_logprobs = await ref_model.forward.route(\n", + " input_ids, max_req_tokens, return_logprobs=True\n", + " )\n", + " t.step(\"reference_model_calculate_logprobs\")\n", + "\n", + " for i, episode in enumerate(group.episodes):\n", + " episode.ref_logprobs = ref_logprobs[i]\n", + " del ref_logprobs, input_ids\n", + " t.step(\"compute_logprobs\")\n", + "\n", + " # Calculate advantages and add to replay buffer\n", + " advantages = await compute_advantages.compute.call_one(group)\n", + " for episode, advantage in zip(group.episodes, advantages):\n", + " episode.advantage = advantage\n", + " await replay_buffer.add.call_one(episode)\n", + "\n", + " # Log metrics\n", + " rollout_count += 1\n", + " record_metric(\n", + " \"main/continuous_rollouts/count_rollout_iterations\", 1, Reduce.SUM\n", + " )\n", + " t.stop()" + ] + }, + { + "cell_type": "markdown", + "id": "57c316dc-11b5-48ea-8b03-e1bb9d9d1f2b", + "metadata": {}, + "source": [ + "## Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "916a0e79-aded-4ee3-b1a8-db0e772996c9", + "metadata": {}, + "outputs": [], + "source": [ + "async def continuous_training():\n", + " training_step = 0\n", + " restart_tracer = True # Flag to control when to restart tracer\n", + " while True:\n", + " # Restart tracer when needed (initial start 
or after completing a training step)\n", +    "        # Otherwise, we cannot measure time waiting for buffer\n", +    "        if restart_tracer:\n", +    "            t = Tracer(\"main_perf/continuous_training\")\n", +    "            t.start()\n", +    "            restart_tracer = False\n", +    "\n", +    "        batch = await replay_buffer.sample.call_one(\n", +    "            curr_policy_version=training_step\n", +    "        )\n", +    "        if batch is None:\n", +    "            await asyncio.sleep(0.1)\n", +    "        else:\n", +    "            t.step(\"waiting_for_buffer\")\n", +    "\n", +    "            inputs, targets = batch\n", +    "            await trainer.train_step.call(inputs, targets)\n", +    "            training_step += 1\n", +    "            t.step(\"train_step\")\n", +    "\n", +    "            await trainer.push_weights.call(training_step)\n", +    "            t.step(\"push_weights\")\n", +    "\n", +    "            await policy.update_weights.fanout(training_step)\n", +    "            t.step(\"update_weights\")\n", +    "\n", +    "            if training_step >= 2:\n", +    "                await drop_weights(training_step - 1)\n", +    "                t.step(\"drop_weights\")\n", +    "\n", +    "            t.stop()\n", +    "            restart_tracer = True\n", +    "\n", +    "            # Flush metrics every training step to WandB\n", +    "            await mlogger.flush.call_one(training_step)" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "4542863b-59c5-40bc-896c-6d8d44ada00f", +   "metadata": {}, +   "source": [ +    "## Run" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "58194c13-b75e-405d-ab11-18cbe1874d92", +   "metadata": { +    "scrolled": true +   }, +   "outputs": [], +   "source": [ +    "num_rollout_threads = 1\n", +    "num_training_threads = 1\n", +    "\n", +    "rollout_tasks = [\n", +    "    asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)\n", +    "]\n", +    "training_task = asyncio.create_task(continuous_training())\n", +    "\n", +    "try:\n", +    "    await asyncio.gather(*rollout_tasks, training_task)\n", +    "except KeyboardInterrupt:\n", +    "    print(\"Training interrupted by user\")\n", +    "    for rollout_task in rollout_tasks:\n", +    "        rollout_task.cancel()\n", +    "    training_task.cancel()" +   ] +  }, +  { +   "cell_type": "markdown", +   "id": "b4603b80-1f25-49a1-920e-d24f38dfc687", +   "metadata": {}, +   "source": [ +    "## Shutdown" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": null, +   "id": "0d74e781-c253-4bd0-929f-bd4ad516ba81", +   "metadata": {}, +   "outputs": [], +   "source": [ +    "await mlogger.shutdown.call_one()\n", +    "await asyncio.sleep(2)\n", +    "\n", +    "await asyncio.gather(\n", +    "    DatasetActor.shutdown(dataloader),\n", +    "    policy.shutdown(),\n", +    "    RLTrainer.shutdown(trainer),\n", +    "    ReplayBuffer.shutdown(replay_buffer),\n", +    "    ComputeAdvantages.shutdown(compute_advantages),\n", +    "    ref_model.shutdown(),\n", +    "    reward_actor.shutdown(),\n", +    ")\n", +    "await shutdown()" +   ] +  } + ], + "metadata": { +  "kernelspec": { +   "display_name": "forge", +   "language": "python", +   "name": "forge" +  }, +  "language_info": { +   "codemirror_mode": { +    "name": "ipython", +    "version": 3 +   }, +   "file_extension": ".py", +   "mimetype": "text/x-python", +   "name": "python", +   "nbconvert_exporter": "python", +   "pygments_lexer": "ipython3", +   "version": "3.10.18" +  } + }, + "nbformat": 4, + "nbformat_minor": 5 +}