diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index f2b7f6e7d..7d942e04b 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -264,15 +264,15 @@ async def main(cfg: DictConfig):
         ref_model,
         reward_actor,
     ) = await asyncio.gather(
-        DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
+        DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Policy.options(**cfg.services.policy).as_service(**cfg.policy),
-        RLTrainer.options(**cfg.services.trainer).as_service(
+        RLTrainer.options(**cfg.actors.trainer).as_actor(
             **cfg.trainer, loss=simple_grpo_loss
         ),
-        ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
+        ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
         ),
-        ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
+        ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
         ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
         RewardActor.options(**cfg.services.reward_actor).as_service(
             reward_functions=[MathReward(), ThinkingReward()]
@@ -283,9 +283,9 @@ async def main(cfg: DictConfig):
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        pad_id = await dataloader.pad_token.route()
+        pad_id = await dataloader.pad_token.call_one()
         while True:
-            sample = await dataloader.sample.route()
+            sample = await dataloader.sample.call_one()
             if sample is None:
                 print("Dataloader is empty, exiting continuous rollout")
                 return
@@ -332,17 +332,17 @@ async def continuous_rollouts():
             del ref_logits, ref_logprobs, input_ids

             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.route(group)
+            advantages = await compute_advantages.compute.call_one(group)
             for episode, advantage in zip(group.episodes, advantages):
                 episode.advantage = advantage
-                await replay_buffer.add.route(episode)
+                await replay_buffer.add.call_one(episode)

             # Log metrics
             avg_response_len = (
                 sum(len(e.response_tokens) for e in group.episodes) / group_size
             )
             mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
-            buffer_size = await replay_buffer._numel.route()
+            buffer_size = await replay_buffer._numel.call_one()
             mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
             avg_reward = sum(e.reward for e in group.episodes) / group_size
             mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
@@ -352,15 +352,18 @@ async def continuous_rollouts():
     async def continuous_training():
         training_step = 0
         while True:
-            batch = await replay_buffer.sample.route(curr_policy_version=training_step)
+            batch = await replay_buffer.sample.call_one(
+                curr_policy_version=training_step
+            )
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
                 inputs, targets = batch
-                loss = await trainer.train_step.route(inputs, targets)
+                loss = await trainer.train_step.call_one(inputs, targets)
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                await trainer.push_weights.fanout(training_step)
+
+                await trainer.push_weights.call(training_step)
                 await policy.update_weights.fanout(training_step)

     print("Starting GRPO training loops...")
@@ -377,11 +380,11 @@ async def continuous_training():
     finally:
         print("Shutting down...")
         await asyncio.gather(
-            dataloader.shutdown(),
+            DatasetActor.shutdown(dataloader),
             policy.shutdown(),
-            trainer.shutdown(),
-            replay_buffer.shutdown(),
-            compute_advantages.shutdown(),
+            RLTrainer.shutdown(trainer),
+            ReplayBuffer.shutdown(replay_buffer),
+            ComputeAdvantages.shutdown(compute_advantages),
             ref_model.shutdown(),
             reward_actor.shutdown(),
         )
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 7f1a65e1a..05171e3e9 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -101,31 +101,29 @@ ref_model:

 # All resource allocations
 services:
-  dataset:
-    procs: 1
-    num_replicas: 1
-    with_gpus: false
   policy:
     procs: ${policy.engine_config.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
-  trainer:
+  ref_model:
     procs: 1
     num_replicas: 1
     with_gpus: true
-  replay_buffer:
+  reward_actor:
     procs: 1
     num_replicas: 1
     with_gpus: false
-  ref_model:
+
+actors:
+  dataset:
+    procs: 1
+    with_gpus: false
+  trainer:
     procs: 1
-    num_replicas: 1
     with_gpus: true
-  compute_advantages:
+  replay_buffer:
     procs: 1
-    num_replicas: 1
     with_gpus: false
-  reward_actor:
+  compute_advantages:
     procs: 1
-    num_replicas: 1
     with_gpus: false
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index 51ca387e5..1d2a2e59e 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -106,31 +106,29 @@ ref_model:

 # All resource allocations
 services:
-  dataset:
-    procs: 1
-    num_replicas: 1
-    with_gpus: false
   policy:
     procs: ${policy.engine_config.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
-  trainer:
-    procs: 2
+  ref_model:
+    procs: 1
     num_replicas: 1
     with_gpus: true
-  replay_buffer:
+  reward_actor:
     procs: 1
     num_replicas: 1
     with_gpus: false
-  ref_model:
+
+actors:
+  dataset:
     procs: 1
-    num_replicas: 1
+    with_gpus: false
+  trainer:
+    procs: 2
     with_gpus: true
-  compute_advantages:
+  replay_buffer:
     procs: 1
-    num_replicas: 1
     with_gpus: false
-  reward_actor:
+  compute_advantages:
     procs: 1
-    num_replicas: 1
     with_gpus: false
diff --git a/apps/grpo/qwen3_multinode.yaml b/apps/grpo/qwen3_multinode.yaml
index ade01855f..cc0c913cf 100644
--- a/apps/grpo/qwen3_multinode.yaml
+++ b/apps/grpo/qwen3_multinode.yaml
@@ -46,33 +46,31 @@ ref_model:
   model_name: ${model}

 services:
-  dataset:
-    procs: 1
-    num_replicas: 1
-    with_gpus: false
   policy:
     procs: 1
     hosts: 1
     num_replicas: 1
     with_gpus: true
-  trainer:
+  ref_model:
     procs: 1
-    hosts: 1
     num_replicas: 1
     with_gpus: true
-  replay_buffer:
+  reward_actor:
     procs: 1
     num_replicas: 1
     with_gpus: false
+
+actors:
+  dataset:
+    procs: 1
+    with_gpus: false
   compute_advantages:
     procs: 1
-    num_replicas: 1
     with_gpus: false
-  ref_model:
+  trainer:
     procs: 1
-    num_replicas: 1
+    hosts: 1
     with_gpus: true
-  reward_actor:
+  replay_buffer:
     procs: 1
-    num_replicas: 1
     with_gpus: false
diff --git a/apps/rl/__init__.py b/apps/rl/__init__.py
deleted file mode 100644
index 2e41cd717..000000000
--- a/apps/rl/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
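
Sketch (not part of the patch): the apps/grpo changes above move single-instance components (dataset, trainer, replay buffer, advantage computation) from the services: block into a new actors: block, spawn them with as_actor() instead of as_service(), address them with call_one() instead of route(), and shut them down through the class rather than the handle. The snippet below is a minimal illustration of that split, assuming the Forge calls that appear in this diff; DatasetActor and Policy are the classes used by apps/grpo/main.py, and everything else around them is illustrative.

import asyncio

async def spawn_and_use(cfg):
    # Singleton actor: resourced from cfg.actors, spawned with as_actor(),
    # endpoints invoked with call_one().
    dataloader = await DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset)
    pad_id = await dataloader.pad_token.call_one()

    # Replicated service: resourced from cfg.services, spawned with as_service(),
    # endpoints invoked with route() or fanout().
    policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
    await policy.update_weights.fanout(0)

    # Actors are torn down via the class, services via the handle.
    await asyncio.gather(DatasetActor.shutdown(dataloader), policy.shutdown())
    return pad_id
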
diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml
deleted file mode 100644
index a7d2e3d96..000000000
--- a/apps/rl/llama3_8b.yaml
+++ /dev/null
@@ -1,62 +0,0 @@
-# Config for GRPO finetuning using a Llama3.1 8B Instruct model
-#
-# This config assumes that you've run the following command before launching
-# this run:
-#   export HF_HUB_DISABLE_XET=1
-#   uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct
-
-
-trainer:
-  comm:
-    trace_buf_size: 0
-
-  model:
-    name: llama3
-    flavor: 8B
-    hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct
-
-  optimizer:
-    name: AdamW
-    lr: 1e-5
-    eps: 1e-8
-
-  lr_scheduler:
-    warmup_steps: 1
-
-  training:
-    local_batch_size: 1
-    seq_len: 2048
-    max_norm: 1.0
-    steps: 5
-    dataset: "c4"
-
-  compile:
-    enable: false
-
-  parallelism:
-    data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: 4
-    tensor_parallel_degree: 1
-    pipeline_parallel_degree: 1
-    context_parallel_degree: 1
-    expert_parallel_degree: 1
-    disable_loss_parallel: false
-
-  checkpoint:
-    enable: true
-    folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
-    initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
-    initial_load_in_hf: true
-    last_save_in_hf: true
-    interval: 500
-    async_mode: "disabled"
-
-  activation_checkpoint:
-    mode: selective
-    selective_ac_option: op
-
-replay_buffer:
-  batch_size: 4
-  max_policy_age: 2
-  seed: None
-  dp_size: 4
diff --git a/apps/rl/main.py b/apps/rl/main.py
deleted file mode 100644
index 49084b50b..000000000
--- a/apps/rl/main.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-"""A working example showcasing a practical example of forge with RL.
-
-Run this with:
-    python -m apps.rl.main --config apps/rl/llama3_8b.yaml
-
-"""
-
-import asyncio
-import logging
-import sys
-from dataclasses import dataclass
-from typing import Any
-
-import torch
-import torch.nn.functional as F
-from forge.actors import ReplayBuffer, RLTrainer
-from forge.cli.config import parse
-
-from omegaconf import DictConfig
-from torch import Tensor
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-@dataclass
-class Episode:
-    # TODO: add adtional layer for multi-turn
-    episode_id: str
-    request: str
-    policy_version: int
-    pad_id: int
-    request_len: int
-    response_len: int
-    target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
-    ref_logprobs: Tensor | None = None
-    reward: float | None = None
-    advantage: float | None = None
-
-    @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
-        if tensor.shape[0] < self.request_len:  # left pad
-            diff = self.request_len - tensor.shape[0]
-            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
-        return tensor
-
-    @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
-        if tensor.shape[0] < self.response_len:  # right pad
-            diff = self.response_len - tensor.shape[0]
-            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
-        return tensor
-
-
-def collate(batches: list[list[Episode]]):
-    inputs = []
-    targets = []
-    for batch in batches:
-        request = [e.request_tensor for e in batch]
-        request = torch.stack(request)  # [b x s]
-
-        response = [e.response_tensor for e in batch]
-        response = torch.stack(response)  # [b x s]
-
-        ref_logprobs = [e.ref_logprobs for e in batch]
-        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]
-
-        advantages = [e.advantage for e in batch]
-        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]
-
-        pad_id = batch[0].pad_id
-        mask = response != pad_id
-
-        input = {"tokens": torch.cat([request, response], dim=1)}
-        target = {
-            "response": response,
-            "ref_logprobs": ref_logprobs,
-            "advantages": advantages,
-            "padding_mask": mask,
-        }
-        inputs.append(input)
-        targets.append(target)
-    return inputs, targets
-
-
-def compute_logprobs(
-    logits: Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
-
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-
-    return logprobs
-
-
-def simple_grpo_loss(
-    logits: Tensor,
-    response: Tensor,
-    ref_logprobs: Tensor,
-    advantages: Tensor,
-    padding_mask: Tensor,
-    beta: float = 0.1,
-):
-    """Simplified GRPO Loss for simplified single step updates
-    Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
- """ - logprobs = compute_logprobs(logits, response) - per_token_kl = ( - torch.exp(ref_logprobs.detach() - logprobs) - - (ref_logprobs.detach() - logprobs) - - 1 - ) - per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages - per_token_loss = -(per_token_policy_loss - beta * per_token_kl) - loss = ( - (per_token_loss * padding_mask).sum(dim=1) / (padding_mask.sum(dim=1) + 1e-8) - ).mean() - return loss - - -async def run(cfg: DictConfig): - trainer = await RLTrainer.options( - procs=1, with_gpus=True, num_replicas=4 - ).as_service(**cfg.trainer) - replay_buffer = await ReplayBuffer.options(procs=1, num_replicas=1).as_service( - **cfg.replay_buffer - ) - - print("Services initialized....") - - print("Collecting Data...") - g = torch.manual_seed(0) - global_batch_size = cfg.replay_buffer.batch_size * cfg.replay_buffer.dp_size - for i in range(global_batch_size): - req_len, res_len = torch.randint(64, 256, (2,), generator=g).tolist() - e = Episode( - episode_id=i, - request="", - policy_version=0, - pad_id=0, - request_len=256, - response_len=256, - request_tokens=torch.randint(64_000, (req_len,), generator=g).tolist(), - response_tokens=torch.randint(64_000, (res_len,), generator=g).tolist(), - ref_logprobs=torch.randn((256,), generator=g), - advantage=torch.randn((1,), generator=g), - ) - await replay_buffer.add.route(e) - - print("Train step...") - inputs, targets = await replay_buffer.sample.route(curr_policy_version=0) - outputs = await trainer.train_step.route(inputs, targets) - print("Loss: ", outputs["loss"]) - - print("Shutting down...") - await trainer.shutdown() - await replay_buffer.shutdown() - - -@parse -def recipe_main(cfg: DictConfig) -> None: - asyncio.run(run(cfg)) - - -if __name__ == "__main__": - sys.exit(recipe_main()) diff --git a/apps/toy_rl/sumdigits-tp.yaml b/apps/toy_rl/sumdigits-tp.yaml index 1f2b30d44..87f58d5ea 100644 --- a/apps/toy_rl/sumdigits-tp.yaml +++ b/apps/toy_rl/sumdigits-tp.yaml @@ -42,22 +42,10 @@ replay_buffer: dp_size: 1 services: - dataset: - procs: 1 - num_replicas: 1 - with_gpus: false policy: procs: 1 num_replicas: 1 with_gpus: true - trainer: - procs: 1 - num_replicas: 1 - with_gpus: true - replay_buffer: - procs: 1 - num_replicas: 1 - with_gpus: false reward_actor: procs: 1 num_replicas: 1 @@ -66,3 +54,14 @@ services: procs: 1 num_replicas: 1 with_gpus: true + +actors: + dataset: + procs: 1 + with_gpus: false + trainer: + procs: 1 + with_gpus: true + replay_buffer: + procs: 1 + with_gpus: false diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py index 24c1fffb7..57971e9b9 100644 --- a/apps/toy_rl/sumdigits.py +++ b/apps/toy_rl/sumdigits.py @@ -16,9 +16,9 @@ import torch import torch.nn.functional as F import torchstore as ts +from forge.actors._torchstore_utils import get_param_key from forge.actors.policy import Policy from forge.actors.replay_buffer import ReplayBuffer -from forge.actors.torchstore_utils import get_param_key from forge.actors.trainer import _qwen3_hf_to_vllm from forge.cli.config import parse from forge.controller.actor import ForgeActor @@ -481,12 +481,10 @@ async def main(cfg: DictConfig): reward_actor, ref_model, ) = await asyncio.gather( - DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), + DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), - Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer), - ReplayBuffer.options(**cfg.services.replay_buffer).as_service( - 
-            **cfg.replay_buffer
-        ),
+        Trainer.options(**cfg.actors.trainer).as_actor(**cfg.trainer),
+        ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(**cfg.replay_buffer),
         RewardActor.options(**cfg.services.reward_actor).as_service(),
         RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
     )
@@ -496,10 +494,10 @@ async def main(cfg: DictConfig):
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
        rollout_count = 0
-        pad_id = await dataloader.pad_token.route()
+        pad_id = await dataloader.pad_token.call_one()
         while True:
             # Pass rollout_count for curriculum learning
-            sample = await dataloader.sample.route(rollout_count)
+            sample = await dataloader.sample.call_one(rollout_count)
             if sample is None:
                 print("Dataloader is empty, exiting continuous rollout")
                 return
@@ -531,7 +529,7 @@ async def continuous_rollouts():
             )
             episode.advantage = episode.reward  # simple case for now
             for episode in group.episodes:
-                await replay_buffer.add.route(episode)
+                await replay_buffer.add.call_one(episode)
             avg_response_len = (
                 sum(len(e.response_tokens) for e in group.episodes) / group_size
             )
@@ -544,20 +542,20 @@ async def continuous_rollouts():
     async def continuous_training():
         training_step = 0
         while True:
-            batch = await replay_buffer.sample.route(curr_policy_version=training_step)
+            batch = await replay_buffer.sample.call_one(
+                curr_policy_version=training_step
+            )
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
-                loss = await trainer.train_step.route(batch[0])
+                loss = await trainer.train_step.call_one(batch[0])
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
                 print(f"loss/training_step: {loss} at training step {training_step}")
-                await trainer.push_weights.fanout(
-                    training_step, vllm_tp_DEPRECATED=policy_tp_size
-                )
+                await trainer.push_weights.call(training_step)
                 await policy.update_weights.fanout(training_step)
                 # NOTE: hard-coded to be on-policy for faster convergence
-                await replay_buffer.clear.fanout()
+                await replay_buffer.clear.call()

     print("Starting training loop.")
     # TODO: Start multiple rollouts once all serivces support it
@@ -573,10 +571,10 @@ async def continuous_training():
     finally:
         print("Shutting down...")
         await asyncio.gather(
-            dataloader.shutdown(),
+            DatasetActor.shutdown(dataloader),
             policy.shutdown(),
-            trainer.shutdown(),
-            replay_buffer.shutdown(),
+            Trainer.shutdown(trainer),
+            ReplayBuffer.shutdown(replay_buffer),
             reward_actor.shutdown(),
         )
         # TODO - add a global shutdown that implicitly shuts down all services
diff --git a/apps/toy_rl/sumdigits.yaml b/apps/toy_rl/sumdigits.yaml
index 86c6a3c52..767bf7f3b 100644
--- a/apps/toy_rl/sumdigits.yaml
+++ b/apps/toy_rl/sumdigits.yaml
@@ -13,6 +13,7 @@ dataset:

 # Policy configuration
 policy:
+  use_dcp: false
   engine_config:
     model: ${model}
     tensor_parallel_size: 1
@@ -24,6 +25,7 @@ policy:
     temperature: 1.0
     top_p: 1.0

+
 # Trainer configuration
 trainer:
   model_name: ${model}
@@ -40,22 +42,10 @@ replay_buffer:
   dp_size: 1

 services:
-  dataset:
-    procs: 1
-    num_replicas: 1
-    with_gpus: false
   policy:
     procs: 1
     num_replicas: 1
     with_gpus: true
-  trainer:
-    procs: 1
-    num_replicas: 1
-    with_gpus: true
-  replay_buffer:
-    procs: 1
-    num_replicas: 1
-    with_gpus: false
   reward_actor:
     procs: 1
     num_replicas: 1
@@ -64,3 +54,14 @@ services:
     procs: 1
     num_replicas: 1
     with_gpus: true
+
+actors:
+  dataset:
+    procs: 1
+    with_gpus: false
+  trainer:
+    procs: 1
+    with_gpus: true
+  replay_buffer:
+    procs: 1
+    with_gpus: false
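
Sketch (not part of the patch): after this change the training loops mix three invocation styles — call_one() on an actor endpoint where a single answer is wanted (replay_buffer.sample, trainer.train_step), call() on an actor endpoint that should reach every process backing the actor (trainer.push_weights, replay_buffer.clear; this reading of call() is an assumption), and fanout() to broadcast to all replicas of a service (policy.update_weights). One training iteration under that convention might look like the following; the endpoint names are the ones appearing in this diff, and error handling is omitted.

async def train_once(trainer, replay_buffer, policy, training_step):
    # Actor endpoint: a single replay buffer process serves the sample.
    batch = await replay_buffer.sample.call_one(curr_policy_version=training_step)
    if batch is None:
        return training_step, None

    inputs, targets = batch
    # Actor endpoint: one trainer handle runs the optimization step.
    loss = await trainer.train_step.call_one(inputs, targets)
    training_step += 1

    # Publish the new weights from the trainer, then tell every policy
    # service replica to pick up the new version.
    await trainer.push_weights.call(training_step)
    await policy.update_weights.fanout(training_step)
    return training_step, loss
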
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index e69290346..f39ed3bac 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -134,6 +134,7 @@ class Policy(PolicyInterface):
     sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
     use_vllm_builtin_load: bool = True
     available_devices: str | None = None
+    use_dcp: bool = True
     # Gets set up by setup
     sampling_params: SamplingParams | None = None
     lora_request: LoRARequest | None = None
@@ -160,6 +161,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         engine_config: EngineConfig | Mapping = EngineConfig(),
         sampling_config: SamplingConfig | Mapping = SamplingConfig(),
         available_devices: str | None = None,
+        use_dcp: bool = True,
         **kwargs,
     ) -> "Policy":
         # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
@@ -189,7 +191,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         vllm_config = engine_config.create_vllm_config()
         workers = await worker_procs.spawn(
-            "vllm_worker", PolicyWorker, vllm_config=vllm_config
+            "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
         )

         if isinstance(sampling_config, Mapping):
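
Sketch (not part of the patch): the policy.py hunks thread a new use_dcp flag (default True) from the Policy dataclass and Policy.launch() down to the spawned vLLM workers, and apps/toy_rl/sumdigits.yaml opts out with use_dcp: false. The flow, assuming only what the hunks show (worker_procs, PolicyWorker, and vllm_config come from the surrounding launch() code that is not part of this diff):

async def launch_workers(cfg, worker_procs, vllm_config):
    # cfg.policy mirrors YAML such as:
    #   policy:
    #     use_dcp: false
    #     engine_config: ...
    use_dcp = cfg.policy.get("use_dcp", True)  # dataclass default is True

    # Policy.launch() now forwards the flag when spawning the workers.
    workers = await worker_procs.spawn(
        "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
    )
    return workers
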