diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 4a51f7225..b226d83ac 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -13,6 +13,12 @@ from typing import Dict, List
 
 import torch
+
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+from forge.data.sharding import VLLMSharding
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore import MultiProcessStore
 from torchstore._state_dict_utils import DELIM
@@ -37,12 +43,6 @@ from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
-
 logger = logging.getLogger(__name__)
 
@@ -329,7 +329,7 @@ class PolicyWorker(ForgeActor):
     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
     vllm_args: EngineArgs = None
-    state_dict_key: str = "model_state_dict"
+    state_dict_key: str = "v0"  # Should track the dynamic version of the state dict (f"v{step}").
 
     def __post_init__(self):
         """Build vLLM Arguments
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4232ca5ca..dc1eeec51 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import logging
 import math
 import os
@@ -12,7 +13,14 @@ from dataclasses import dataclass, field, fields
 
 import torch
+import torchtitan.experiments.forge.train_spec as forge_train_spec
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
+from torchstore import MultiProcessStore
+from torchstore._state_dict_utils import push_state_dict
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -30,8 +38,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -89,7 +95,8 @@ def __post_init__(self):
         os.environ.update(env)
 
     @endpoint
-    async def setup(self):
+    async def setup(self, store: MultiProcessStore | None = None):
+        self.store = store
         # TODO: update ForgeEngine to not use ForgeJobConfig
         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
@@ -185,6 +192,18 @@ def train_step(self, batch) -> None:
         self.engine.lr_schedulers.step()
         self.current_step += 1
+
+        # Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
+        # TODOs:
+        # 1. Figure out whether there is value in calling state_dict_adapter.to_hf().
+        # 2. Checkpointer invokes state-dict flattening during dcp_save for [MODEL];
+        #    we may need to replicate the same in this code path.
+        # 3. Integrate the zero-overhead version of push_state_dict.
+        # 4. Figure out a way to notify the generator app that weights are ready;
+        #    this is beyond the initial integration milestone.
+        # 5. Unify the CheckpointManager and TorchStore weight-save control paths.
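+        # Note (assumption): push_state_dict is expected to write each tensor
+        # under a DELIM-joined key, e.g. f"v{step}{DELIM}{param_fqn}", so the
+        # Policy side can fetch parameters individually by version prefix.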
+        push_state_dict(
+            self.store, self.engine.checkpointer.states, f"v{self.current_step}"
+        )
+        # if self.current_step % self.train_config.val_every_n_steps == 0:
+        #     self.validate()
         self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,
@@ -262,7 +281,16 @@ def train_step(self, batch) -> None:
 
     @endpoint
     def push_weights(self) -> None:
-        pass
+        if self.store is None:
+            raise RuntimeError("No torchstore configured; cannot publish model weights")
+        # Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
+        # TODOs:
+        # 1. Figure out whether there is value in calling state_dict_adapter.to_hf().
+        # 2. Checkpointer invokes state-dict flattening during dcp_save for [MODEL];
+        #    we may need to replicate the same in this code path.
+        push_state_dict(
+            self.store, self.engine.checkpointer.states, f"v{self.current_step}"
+        )
 
     @endpoint
     async def cleanup(self) -> None:
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 733abcd21..576ab8d25 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -6,6 +6,7 @@
 
 import os
 
+from omegaconf import DictConfig
 import pytest
 import pytest_asyncio
@@ -17,9 +18,10 @@
+from monarch.actor import ProcMesh
 from torchstore import MultiProcessStore
 from torchstore._state_dict_utils import push_state_dict
 from transformers import AutoModelForCausalLM
 from vllm.utils import get_open_port
 
+from forge.actors.trainer import RLTrainer
+from forge.controller import ForgeActor
+
 requires_cuda = pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="CUDA not available",
@@ -167,22 +169,7 @@ def validate_loaded_tensors_equals_original(
         f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original"
     )
 
-
-async def run_policy_integration(store, original_state_dict, num_gpus):
-    """
-    Common helper function to test Policy integration with different GPU configurations.
-
-    Args:
-        store: TorchStore instance
-        original_state_dict: Original state dict for validation
-        num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
-        test_name: Name for test identification in validation messages
-    """
-    print(f"=== PHASE 2: Testing Policy Integration (GPUs: {num_gpus}) ===")
-
-    state_dict_key = "llama3_8b_state_dict"
-
-    # Set up environment variables for vLLM distributed initialization
+async def env_setup_and_get_proc_mesh(num_gpus: int) -> ProcMesh:
     if num_gpus == 1:
         # Single GPU setup
         os.environ.setdefault("MASTER_ADDR", "localhost")
@@ -198,20 +185,44 @@
         os.environ["MASTER_ADDR"] = master_addr
         os.environ["MASTER_PORT"] = master_port
         print(f"Using MASTER_PORT: {master_port} for tensor parallel Policy")
-
-    rank = int(os.environ.get("RANK", "0"))
-
-    policy_mesh = await proc_mesh(
+
+    # Read the final values back from the environment so both branches work.
+    return await proc_mesh(
         gpus=num_gpus,
         env={
-            "MASTER_ADDR": master_addr,
-            "MASTER_PORT": master_port,
+            "MASTER_ADDR": os.environ["MASTER_ADDR"],
+            "MASTER_PORT": os.environ["MASTER_PORT"],
         },
-    )
+    )
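+
+
+# Note (assumption): RLTrainer.push_weights publishes under f"v{current_step}",
+# while Policy reads `state_dict_key` (default "v0" in PolicyWorker); the two
+# must agree for Policy.update() to find the freshly pushed weights.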
+async def run_rl_trainer_and_push_weights_to_torch_store(
+    proc_mesh, store, cfg: DictConfig, num_gpus
+) -> ForgeActor:
+    # Pop `processes` before spawning so cfg.trainer is not mutated mid-call.
+    processes = cfg.trainer.pop("processes")
+    rl_trainer_actor = await proc_mesh.spawn(
+        "rl_trainer_actor",
+        RLTrainer,
+        cfg=cfg.trainer,
+        processes=processes,
+        set_address=True,
+    )
+    await rl_trainer_actor.setup.call(store)
+    print("RLTrainer setup completed successfully!")
+    # Push weights to torchstore
+    await rl_trainer_actor.push_weights.call()
+    return rl_trainer_actor
+
+
+async def run_policy_actor_and_update_weights_from_torch_store(
+    proc_mesh, store, num_gpus
+) -> ForgeActor:
+    """
+    Common helper to spawn a Policy actor and load weights from torchstore.
+
+    Args:
+        proc_mesh: Process mesh to spawn the Policy actor on
+        store: TorchStore instance holding the published weights
+        num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
+    """
+    print(f"=== PHASE 2: Testing Policy Integration (GPUs: {num_gpus}) ===")
+    state_dict_key = "llama3_8b_state_dict"
 
     # Spawn Policy as a proper Monarch actor
-    policy = await policy_mesh.spawn(
-        "policy",
+    policy_actor = await proc_mesh.spawn(
+        "policy_actor",
         Policy,
         model="meta-llama/Meta-Llama-3.1-8B-Instruct",
         tensor_parallel_size=num_gpus,
@@ -221,24 +232,13 @@
         state_dict_key=state_dict_key,
     )
 
-    await policy.setup.call(store)
+    await policy_actor.setup.call(store)
     print("Setup completed successfully!")
 
     print("Calling Policy.update() to load weights from torchstore...")
-    await policy.update.call()
-    print("Successfully called Policy.update() to load weights from torchstore!")
-
-    model_params = await policy.get_model_params.call()
-    loaded_state_dict = (
-        model_params._values[0] if hasattr(model_params, "_values") else model_params
-    )
-    print("Successfully got model state dict after update")
-
-    validate_loaded_tensors_equals_original(
-        loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
-    )
-
-    print("Test passed! State dict successfully loaded into Policy!")
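+    # Note (assumption): Policy.update() fetches each parameter from
+    # torchstore under `state_dict_key` and reshards it for the local
+    # tensor-parallel rank (see VLLMSharding in forge.data.sharding).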
+    await policy_actor.update.call()
+    print("Successfully called Policy.update() to load weights from torchstore!")
+    return policy_actor
 
 
 @pytest_asyncio.fixture(scope="session")
@@ -277,15 +277,32 @@ async def llama3_torchstore_setup():
 
     return store, converted_state_dict
 
 
+async def validate_updated_policy_engine_weights(
+    policy_actor: ForgeActor, original_state_dict, num_gpus: int
+) -> None:
+    model_params = await policy_actor.get_model_params.call()
+    loaded_state_dict = (
+        model_params._values[0] if hasattr(model_params, "_values") else model_params
+    )
+    print("Successfully got model state dict after update")
+
+    rank = int(os.environ.get("RANK", "0"))
+    validate_loaded_tensors_equals_original(
+        loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
+    )
+    print("Test passed! State dict successfully loaded into Policy!")
+
+
 @pytest.mark.asyncio
 @requires_cuda
 async def test_llama3_policy_update_single(llama3_torchstore_setup):
     print("Starting Llama 3 8B torchstore test (single GPU)...")
+    num_gpus = 1
+    proc_mesh = await env_setup_and_get_proc_mesh(num_gpus)
     store, original_state_dict = llama3_torchstore_setup
-    await run_policy_integration(store, original_state_dict, num_gpus=1)
+    policy_actor = await run_policy_actor_and_update_weights_from_torch_store(
+        proc_mesh, store, num_gpus=num_gpus
+    )
+    await validate_updated_policy_engine_weights(
+        policy_actor=policy_actor,
+        original_state_dict=original_state_dict,
+        num_gpus=num_gpus,
+    )
 
     print(
         "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
     )
@@ -295,16 +312,29 @@
 @requires_cuda
 async def test_llama3_policy_update_tp(llama3_torchstore_setup):
     print("Starting tensor parallel test (load full state dict into sharded model)...")
+    num_gpus = 2
     if torch.cuda.device_count() < 2:
         pytest.skip(
             f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
         )
+    proc_mesh = await env_setup_and_get_proc_mesh(num_gpus)
     store, original_state_dict = llama3_torchstore_setup
-    await run_policy_integration(store, original_state_dict, num_gpus=2)
+    policy_actor = await run_policy_actor_and_update_weights_from_torch_store(
+        proc_mesh, store, num_gpus=num_gpus
+    )
+    await validate_updated_policy_engine_weights(
+        policy_actor=policy_actor,
+        original_state_dict=original_state_dict,
+        num_gpus=num_gpus,
+    )
 
     print(
         "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
     )
+
+
+@pytest.mark.asyncio
+@requires_cuda
+async def test_llama3_e2e_weight_exchange():
+    num_gpus = 1
+    proc_mesh = await env_setup_and_get_proc_mesh(num_gpus)
+    store = await MultiProcessStore.create_store()
+    # TODO: replace the empty config with a real trainer config (model,
+    # optimizer, processes, etc.) before enabling this test.
+    rl_trainer_actor = await run_rl_trainer_and_push_weights_to_torch_store(
+        proc_mesh, store, cfg=DictConfig({}), num_gpus=num_gpus
+    )
+    policy_actor = await run_policy_actor_and_update_weights_from_torch_store(
+        proc_mesh, store, num_gpus=num_gpus
+    )
+    # TODO: use the actor references to directly access weights for validation.
+    print(
+        "Trained weights successfully exchanged between RLTrainer and Policy via TorchStore!"
+    )
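+    # Sketch (assumption): once RLTrainer exposes a get_model_params endpoint
+    # mirroring Policy's, end-to-end validation could reuse the helper above:
+    #
+    #     trainer_params = await rl_trainer_actor.get_model_params.call()
+    #     await validate_updated_policy_engine_weights(
+    #         policy_actor=policy_actor,
+    #         original_state_dict=trainer_params,
+    #         num_gpus=num_gpus,
+    #     )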