diff --git a/apps/grpo/main.py b/apps/grpo/main.py index d4143edb2..cb97e08a3 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -23,7 +23,7 @@ from forge.actors.generator import Generator from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer -from forge.actors.trainer import RLTrainer +from forge.actors.trainer import TitanTrainer from forge.controller.actor import ForgeActor from forge.controller.provisioner import init_provisioner, shutdown from forge.data.rewards import MathReward, ThinkingReward @@ -318,7 +318,7 @@ async def main(cfg: DictConfig): ) = await asyncio.gather( DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset), Policy.options(**cfg.services.policy).as_service(**cfg.policy), - RLTrainer.options(**cfg.actors.trainer).as_actor( + TitanTrainer.options(**cfg.actors.trainer).as_actor( **cfg.trainer, loss=simple_grpo_loss ), ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor( diff --git a/docs/source/api_trainer.md b/docs/source/api_trainer.md index 75aba94f0..6e66e5418 100644 --- a/docs/source/api_trainer.md +++ b/docs/source/api_trainer.md @@ -7,17 +7,17 @@ The Trainer manages model training in TorchForge, built on top of TorchTitan. It handles forward/backward passes, weight updates, and checkpoint management for reinforcement learning workflows. -## RLTrainer +## TitanTrainer ```{eval-rst} -.. autoclass:: RLTrainer +.. autoclass:: TitanTrainer :members: train_step, push_weights, cleanup :exclude-members: __init__ ``` ## Configuration -The RLTrainer uses TorchTitan's configuration system with the following components: +The TitanTrainer uses TorchTitan's configuration system with the following components: ### Job Configuration diff --git a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md index bf31926a4..37314831c 100644 --- a/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md +++ b/docs/source/tutorial_sources/zero-to-forge/1_RL_and_Forge_Fundamentals.md @@ -96,7 +96,7 @@ graph LR S3["RewardActor"] S4["ReferenceModel"] S5["ReplayBuffer"] - S6["RLTrainer"] + S6["TitanTrainer"] end C1 --> S1 @@ -306,7 +306,7 @@ TorchForge handles behind the scenes: from forge.actors.generator import Generator as Policy from forge.actors.replay_buffer import ReplayBuffer from forge.actors.reference_model import ReferenceModel -from forge.actors.trainer import RLTrainer +from forge.actors.trainer import TitanTrainer from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages from forge.data.rewards import MathReward, ThinkingReward import asyncio @@ -348,7 +348,7 @@ group_size = 1 } ), # Trainer actor with GPU - RLTrainer.options(procs=1, with_gpus=True).as_actor( + TitanTrainer.options(procs=1, with_gpus=True).as_actor( # Trainer config would come from YAML in real usage model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": f"hf://{model}"}, optimizer={"name": "AdamW", "lr": 1e-5}, @@ -378,12 +378,12 @@ group_size = 1 TorchForge has two types of distributed components: - **Services**: Multiple replicas with automatic load balancing (like Policy, RewardActor) -- **Actors**: Single instances that handle their own internal distribution (like RLTrainer, ReplayBuffer) +- **Actors**: Single instances that handle their own internal distribution (like TitanTrainer, ReplayBuffer) We cover this distinction in detail in Part 2, but for now this explains the scaling patterns: - Policy 
service: num_replicas=8 for high inference demand - RewardActor service: num_replicas=16 for parallel evaluation -- RLTrainer actor: Single instance with internal distributed training +- TitanTrainer actor: Single instance with internal distributed training ### Fault Tolerance diff --git a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md index 1dabe0d5f..cc6cfeda3 100644 --- a/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md +++ b/docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md @@ -470,7 +470,7 @@ async def simple_rl_step(): if batch is not None: print("Training on batch...") inputs, targets = batch # GRPO returns (inputs, targets) tuple - loss = await trainer.train_step.call(inputs, targets) # RLTrainer is an actor + loss = await trainer.train_step.call(inputs, targets) # TitanTrainer is an actor print(f"Training loss: {loss}") return loss else: @@ -507,7 +507,7 @@ reward_actor = await RewardActor.options( ) # Training needs fewer but more powerful replicas -trainer = await RLTrainer.options( +trainer = await TitanTrainer.options( procs=1, with_gpus=True # Fewer but GPU-heavy ).as_actor( # Trainer typically uses .as_actor() not .as_service() model={"name": "qwen3", "flavor": "1.7B"}, @@ -580,7 +580,7 @@ import torch from forge.actors.generator import Generator as Policy from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer -from forge.actors.trainer import RLTrainer +from forge.actors.trainer import TitanTrainer from apps.grpo.main import DatasetActor, RewardActor, ComputeAdvantages from forge.data.rewards import MathReward, ThinkingReward @@ -603,7 +603,7 @@ print("Initializing all services...") engine_config={"model": "Qwen/Qwen3-1.7B", "tensor_parallel_size": 1}, sampling_config={"n": 1, "max_tokens": 512} ), - RLTrainer.options(procs=1, with_gpus=True).as_actor( + TitanTrainer.options(procs=1, with_gpus=True).as_actor( model={"name": "qwen3", "flavor": "1.7B", "hf_assets_path": "hf://Qwen/Qwen3-1.7B"}, optimizer={"name": "AdamW", "lr": 1e-5}, training={"local_batch_size": 2, "seq_len": 2048} @@ -667,7 +667,7 @@ print("Shutting down services...") await asyncio.gather( DatasetActor.shutdown(dataloader), policy.shutdown(), - RLTrainer.shutdown(trainer), + TitanTrainer.shutdown(trainer), ReplayBuffer.shutdown(replay_buffer), ComputeAdvantages.shutdown(compute_advantages), ReferenceModel.shutdown(ref_model), diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py index 00e58b9dd..772e2e216 100644 --- a/src/forge/actors/__init__.py +++ b/src/forge/actors/__init__.py @@ -4,9 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings + __all__ = [ "Generator", - "RLTrainer", + "TitanTrainer", + "RLTrainer", # Deprecated, use TitanTrainer "ReplayBuffer", "ReferenceModel", "SandboxedPythonCoder", @@ -18,7 +21,17 @@ def __getattr__(name): from .generator import Generator return Generator + elif name == "TitanTrainer": + from .trainer import TitanTrainer + + return TitanTrainer elif name == "RLTrainer": + warnings.warn( + "RLTrainer is deprecated and will be removed in a future version. 
" + "Please use TitanTrainer instead.", + FutureWarning, + stacklevel=2, + ) from .trainer import RLTrainer return RLTrainer diff --git a/src/forge/actors/trainer/__init__.py b/src/forge/actors/trainer/__init__.py new file mode 100644 index 000000000..8978ab76d --- /dev/null +++ b/src/forge/actors/trainer/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +from .titan import TitanTrainer + +__all__ = ["TitanTrainer", "RLTrainer"] + + +def __getattr__(name): + if name == "RLTrainer": + warnings.warn( + "RLTrainer is deprecated and will be removed in a future version. " + "Please use TitanTrainer instead.", + FutureWarning, + stacklevel=2, + ) + return TitanTrainer + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer/titan.py similarity index 98% rename from src/forge/actors/trainer.py rename to src/forge/actors/trainer/titan.py index c58600f51..e00fe92be 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer/titan.py @@ -53,8 +53,8 @@ @dataclass -class RLTrainer(ForgeActor): - """A reinforcement learning trainer actor for policy optimization training. +class TitanTrainer(ForgeActor): + """A generic trainer actor implementation built on top of TorchTitan. Built on top of TorchTitan's training engine, this actor provides a complete training loop for reinforcement learning. It performs forward and backward passes with gradient diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 202c10686..d4151b5b6 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -16,7 +16,7 @@ import torchstore as ts from forge.actors.generator import Generator -from forge.actors.trainer import RLTrainer +from forge.actors.trainer import TitanTrainer from forge.controller.provisioner import init_provisioner from forge.controller.service.service import uuid @@ -50,7 +50,7 @@ TEST_DCP_DIR = "test_dcp_tmp" -class MockRLTrainer(RLTrainer): +class MockTitanTrainer(TitanTrainer): @endpoint async def zero_out_model_states(self): """This simply sets all model weights to zero.""" @@ -59,7 +59,7 @@ async def zero_out_model_states(self): for k in sd.keys(): if not torch.is_floating_point(sd[k]): logger.info( - f"[MockRLTrainer] zero_out_model_states(): skipping non-float param {k}" + f"[MockTitanTrainer] zero_out_model_states(): skipping non-float param {k}" ) continue sd[k] *= 0.0 @@ -199,14 +199,14 @@ async def _setup_and_teardown(request): ) await ts.initialize(strategy=ts.ControllerStorageVolumes()) - policy, rl_trainer = await asyncio.gather( + policy, titan_trainer = await asyncio.gather( *[ Generator.options(**services_policy_cfg).as_service(**cfg.policy), - MockRLTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg), + MockTitanTrainer.options(**cfg.actors.trainer).as_actor(**trainer_cfg), ] ) - yield policy, rl_trainer + yield policy, titan_trainer # ---- teardown ---- # logger.info("Shutting down services and cleaning up DCP directory..") @@ -214,7 +214,7 @@ async def _setup_and_teardown(request): await asyncio.gather( policy.shutdown(), ts.shutdown(), - RLTrainer.shutdown(rl_trainer), + TitanTrainer.shutdown(titan_trainer), ) # Cleanup DCP directory @@ -235,7 +235,7 @@ class 
TestWeightSync: @requires_cuda async def test_sanity_check(self, _setup_and_teardown): """ - Sanity check for weight sync sharding between RLTrainer and Policy for a given model config. + Sanity check for weight sync sharding between TitanTrainer and Policy for a given model config. The check performs the following steps: - Initialize trainer and push weights v0 (original huggingface ckpt) @@ -245,15 +245,15 @@ async def test_sanity_check(self, _setup_and_teardown): """ - policy, rl_trainer = _setup_and_teardown + policy, titan_trainer = _setup_and_teardown v0 = uuid.uuid4().int v1 = v0 + 1 - await rl_trainer.push_weights.call(policy_version=v0) + await titan_trainer.push_weights.call(policy_version=v0) # Setting everything to zero - await rl_trainer.zero_out_model_states.call() - await rl_trainer.push_weights.call(policy_version=v1) + await titan_trainer.zero_out_model_states.call() + await titan_trainer.push_weights.call(policy_version=v1) await policy.save_model_params.fanout() # Sanity check that before update all the tests pass diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index 55714c49d..8825794b6 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -10,7 +10,7 @@ import torch import torchstore as ts -from forge.actors.trainer import RLTrainer +from forge.actors.trainer import TitanTrainer from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY from forge.controller.provisioner import init_provisioner, shutdown from forge.observability.metric_actors import get_or_create_metric_logger @@ -182,7 +182,7 @@ async def main(cfg: DictConfig): await ts.initialize(strategy=ts.ControllerStorageVolumes()) # Initialize trainer only print("Initializing trainer...") - trainer = await RLTrainer.options(**cfg.actors.trainer).as_actor( + trainer = await TitanTrainer.options(**cfg.actors.trainer).as_actor( **cfg.trainer, loss=simple_grpo_loss ) print("Trainer initialized successfully with following configs!")
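
The deprecation shim introduced above keeps the old `RLTrainer` import path working while steering callers to `TitanTrainer`. The snippet below is a minimal, illustrative check of that behavior, assuming the `forge.actors.trainer` package layout introduced in this diff; the `check_deprecated_alias` helper is hypothetical and not part of the PR.

```python
import warnings

from forge.actors.trainer import TitanTrainer


def check_deprecated_alias() -> None:
    """Verify that the deprecated RLTrainer name still resolves and warns."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Module-level __getattr__ in forge/actors/trainer/__init__.py
        # resolves the old name and emits a FutureWarning.
        from forge.actors.trainer import RLTrainer

        assert RLTrainer is TitanTrainer
        assert any(issubclass(w.category, FutureWarning) for w in caught)


if __name__ == "__main__":
    check_deprecated_alias()
    print("RLTrainer alias resolves to TitanTrainer and emits a FutureWarning")
```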