14 changes: 7 additions & 7 deletions src/forge/actors/policy.py
@@ -13,6 +13,12 @@
from typing import Dict, List

import torch

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM
@@ -37,12 +43,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)

@@ -329,7 +329,7 @@ class PolicyWorker(ForgeActor):
pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None
state_dict_key: str = "model_state_dict"
state_dict_key: str = "v0" # This should be the dynamic-version of the sd.

def __post_init__(self):
"""Build vLLM Arguments
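For context, a minimal sketch of the versioned-key convention this new default assumes: the trainer publishes its prepped state dict under a "v{step}" key, and a Policy whose state_dict_key names that version loads it on update(). Only push_state_dict and the key format come from this PR; the helper below is illustrative.

from torchstore._state_dict_utils import push_state_dict

# Illustrative helper, not part of this PR.
def publish_versioned_state_dict(store, state_dict, step: int) -> str:
    key = f"v{step}"  # e.g. "v0" matches the default PolicyWorker.state_dict_key above
    push_state_dict(store, state_dict, key)
    return key  # the Policy should be pointed at this key before calling update()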
36 changes: 32 additions & 4 deletions src/forge/actors/trainer.py
@@ -5,14 +5,22 @@
# LICENSE file in the root directory of this source tree.


import asyncio
import logging
import math
import os
from collections.abc import Mapping
from dataclasses import dataclass, field, fields

import torch
import torchtitan.experiments.forge.train_spec as forge_train_spec

# from tqdm import tqdm


from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict
from torchtitan.config.job_config import (
ActivationCheckpoint,
Checkpoint,
@@ -30,8 +38,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@@ -89,7 +95,8 @@ def __post_init__(self):
os.environ.update(env)

@endpoint
async def setup(self):
async def setup(self, store: MultiProcessStore = None):
self.store = store
# TODO: update ForgeEngine to not use ForgeJobConfig
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
@@ -185,6 +192,18 @@ def train_step(self, batch) -> None:
self.engine.lr_schedulers.step()

self.current_step += 1

# Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
# TODOs:
# 1. Figure out if there is value in calling state_dict_adapter.to_hf().
# 2. Checkpointer invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
# 3. Integrate a zero-overhead version of push_state_dict.
# 4. Figure out a way to notify the generator app that weights are ready. This is beyond the initial integration milestone.
# 5. Unify the CheckpointManager and TorchStore weight-save control paths.
push_state_dict(self.store, self.engine.checkpointer.states, f"v{self.current_step}")
# if self.current_step % self.train_config.val_every_n_steps == 0:
# self.validate()
self.engine.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
@@ -262,7 +281,16 @@ def train_step(self, batch) -> None:

@endpoint
def push_weights(self) -> None:
pass
if self.store is None:
raise Exception("No torchstore configured; cannot publish model weights")
# Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
# TODOs:
# 1. Figure out if there is value in calling state_dict_adapter.to_hf().
# 2. Checkpointer invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
push_state_dict(
self.store, self.engine.checkpointer.states, f"v{self.current_step}"
)

@endpoint
async def cleanup(self) -> None:
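For orientation, a minimal sketch of how a driver loop might exchange weights between the trainer and the policy through torchstore, assuming the Monarch endpoint .call() convention used in the tests below. The trainer, policy, and batches names are assumed already-spawned actor handles and training data, not part of this PR.

from torchstore import MultiProcessStore

# Hypothetical driver loop; only the endpoints (setup, train_step, push_weights, update)
# and MultiProcessStore.create_store come from this PR.
async def drive(trainer, policy, batches):
    store = await MultiProcessStore.create_store()
    await trainer.setup.call(store)
    await policy.setup.call(store)
    for batch in batches:
        await trainer.train_step.call(batch)  # train_step also pushes "v{step}" to the store
        await trainer.push_weights.call()     # explicit publish of the current version
        await policy.update.call()            # policy reloads the published weights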
120 changes: 75 additions & 45 deletions tests/integration_tests/test_policy_update.py
@@ -6,6 +6,7 @@

import os

from forge.actors.trainer import RLTrainer
import pytest
import pytest_asyncio

@@ -17,9 +18,10 @@
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict
from transformers import AutoModelForCausalLM
from forge.controller import ForgeActor

from vllm.utils import get_open_port

from omegaconf import DictConfig

requires_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
@@ -167,22 +169,7 @@ def validate_loaded_tensors_equals_original(
f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original"
)


async def run_policy_integration(store, original_state_dict, num_gpus):
"""
Common helper function to test Policy integration with different GPU configurations.

Args:
store: TorchStore instance
original_state_dict: Original state dict for validation
num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
test_name: Name for test identification in validation messages
"""
print(f"=== PHASE 2: Testing Policy Integration (GPUs: {num_gpus}) ===")

state_dict_key = "llama3_8b_state_dict"

# Set up environment variables for vLLM distributed initialization
async def env_setup_and_get_proc_mesh(num_gpus: int) -> ProcMesh:
if num_gpus == 1:
# Single GPU setup
os.environ.setdefault("MASTER_ADDR", "localhost")
@@ -198,20 +185,44 @@ async def run_policy_integration(store, original_state_dict, num_gpus):
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
print(f"Using MASTER_PORT: {master_port} for tensor parallel Policy")

rank = int(os.environ.get("RANK", "0"))

policy_mesh = await proc_mesh(

return await proc_mesh(
gpus=num_gpus,
env={
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
},
)
)


async def run_rl_trainer_and_push_weights_to_torch_store(proc_mesh, store, cfg: DictConfig, num_gpus) -> ForgeActor:
rl_trainer_actor = await proc_mesh.spawn(
"rl_trainer_actor",
RLTrainer,
cfg=cfg.trainer,
processes=cfg.trainer.pop("processes"),
set_address=True,
)
await rl_trainer_actor.setup.call(store)
print("RLTrainer setup completed successfully!")
# Push weights to torchstore
await rl_trainer_actor.push_weights.call()
return rl_trainer_actor



async def run_policy_actor_and_update_weights_from_torch_store(proc_mesh, store, num_gpus) -> ForgeActor:
"""
Common helper function to test Policy integration with different GPU configurations.

Args:
store: TorchStore instance
original_state_dict: Original state dict for validation
num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
test_name: Name for test identification in validation messages
"""
print(f"=== PHASE 2: Testing Policy Integration (GPUs: {num_gpus}) ===")
state_dict_key = "llama3_8b_state_dict"

# Spawn Policy as a proper Monarch actor
policy = await policy_mesh.spawn(
"policy",
policy_actor = await proc_mesh.spawn(
"policy_actor",
Policy,
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
tensor_parallel_size=num_gpus,
@@ -221,24 +232,13 @@ async def run_policy_integration(store, original_state_dict, num_gpus):
state_dict_key=state_dict_key,
)

await policy.setup.call(store)
await policy_actor.setup.call(store)
print("Setup completed successfully!")

print("Calling Policy.update() to load weights from torchstore...")
await policy.update.call()
print("Successfully called Policy.update() to load weights from torchstore!")

model_params = await policy.get_model_params.call()
loaded_state_dict = (
model_params._values[0] if hasattr(model_params, "_values") else model_params
)
print("Successfully got model state dict after update")

validate_loaded_tensors_equals_original(
loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
)

print("Test passed! State dict successfully loaded into Policy!")
await policy_actor.update.call()
print("Successfully called Policy.update() to load weights from torchstore!")
return policy_actor


@pytest_asyncio.fixture(scope="session")
@@ -277,15 +277,32 @@ async def llama3_torchstore_setup():
return store, converted_state_dict


async def validate_updated_policy_engine_weights(policy_actor: ForgeActor, original_state_dict, num_gpus: int) -> None:
model_params = await policy_actor.get_model_params.call()
loaded_state_dict = (
model_params._values[0] if hasattr(model_params, "_values") else model_params
)
print("Successfully got model state dict after update")

rank = int(os.environ.get("RANK", "0"))
validate_loaded_tensors_equals_original(
loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
)
print("Test passed! State dict successfully loaded into Policy!")



@pytest.mark.asyncio
@requires_cuda
async def test_llama3_policy_update_single(llama3_torchstore_setup):
print("Starting Llama 3 8B torchstore test (single GPU)...")

num_gpus = 1
proc_mesh = await env_setup_and_get_proc_mesh(num_gpus)
store, original_state_dict = llama3_torchstore_setup

await run_policy_integration(store, original_state_dict, num_gpus=1)

policy_actor = await run_policy_actor_and_update_weights_from_torch_store(proc_mesh, store, num_gpus=num_gpus)
await validate_updated_policy_engine_weights(policy_actor=policy_actor, original_state_dict=original_state_dict, num_gpus=num_gpus)

print(
"Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
)
@@ -295,16 +312,29 @@
@requires_cuda
async def test_llama3_policy_update_tp(llama3_torchstore_setup):
print("Starting tensor parallel test (load full state dict into sharded model)...")

num_gpus = 2
if torch.cuda.device_count() < 2:
pytest.skip(
f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
)

proc_mesh = await env_setup_and_get_proc_mesh(num_gpus)
store, original_state_dict = llama3_torchstore_setup

await run_policy_integration(store, original_state_dict, num_gpus=2)
policy_actor = await run_policy_actor_and_update_weights_from_torch_store(proc_mesh, store, num_gpus=num_gpus)
await validate_updated_policy_engine_weights(policy_actor=policy_actor, original_state_dict=original_state_dict, num_gpus=num_gpus)

print(
"Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
)


@pytest.mark.asyncio
@requires_cuda
async def test_llama3_e2e_weight_exchange(llama3_torchstore_setup):
num_gpus = 1
proc_mesh = await env_setup_and_get_proc_mesh(num_gpus)
store = await MultiProcessStore.create_store()
rl_training_actor = await run_rl_trainer_and_push_weights_to_torch_store(proc_mesh, store, cfg=DictConfig({}), num_gpus=num_gpus)
policy_actor = await run_policy_actor_and_update_weights_from_torch_store(proc_mesh, store, num_gpus=num_gpus)
# Use actor references to directly access weights for validation purposes.
print("Trained weights are successfully exchanged between RLTrainer and Policy via TorchStore!")