Merged
Changes from 17 commits
58 changes: 54 additions & 4 deletions src/forge/actors/trainer.py
@@ -8,11 +8,16 @@
import logging
import math
import os
import time
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
from typing import Any, Dict

import torch
import torchstore as ts
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torchstore.state_dict_utils import DELIM
from torchtitan.config.job_config import (
ActivationCheckpoint,
Checkpoint,
@@ -25,7 +30,6 @@
Parallelism,
Training,
)

from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
@@ -50,6 +54,7 @@ class RLTrainer(ForgeActor):
compile: Compile = field(default_factory=Compile)
float8: Float8 = field(default_factory=Float8)
comm: Comm = field(default_factory=Comm)
state_dict_key: str = "model_state_dict"

def __post_init__(self):
"""Initializes config types and env variables.
@@ -68,7 +73,7 @@ def __post_init__(self):
f"{f.name} should be a {f.type} type or a dict like object"
)

self.current_step = 0
self.current_step = 1 # fragile contract.
Contributor:
Why are we starting at 1? Also, we probably want a TODO to update this from the checkpoint.

Contributor Author:
Because the policy engine starts at 1. Let's keep this fragile contract as it is. The true version has to come from a config or an external book-keeping entity.

Contributor:
I don't think we can change this without risking breaking checkpoint expectations on the titan side. I'd rather just use a separate variable in the trainer for the "checkpoint name" (it can be a property that's just current_step + 1 for now). This could also be passed in from the controller, which would be better.
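A minimal sketch of the suggestion above, assuming the trainer keeps its own `current_step` counter; the `checkpoint_name` property and the standalone class are hypothetical, not part of this PR:

```python
from dataclasses import dataclass


@dataclass
class _StepBookkeeping:
    # Leave current_step alone so titan's checkpoint expectations stay intact.
    current_step: int = 0

    @property
    def checkpoint_name(self) -> int:
        # Version used when pushing weights; offset by one so it lines up with
        # the policy engine, which starts counting at 1.
        return self.current_step + 1
```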

self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
@@ -261,15 +266,60 @@ def train_step(self, batch) -> None:
# return {"loss": avg_loss, "groups_processed": num_groups_processed}

@endpoint
def push_weights(self) -> None:
pass
async def push_weights(self, policy_version: int) -> None:
# Save to torchstore. Hacking into the Checkpointer's prepped state-dict for now.
# TODO:
# 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
# 2. Unify CheckpointManager and TorchStore weights save control path.
if "model" not in self.engine.checkpointer.states:
raise RuntimeError("Model state not found in checkpointer state")
sd = self.engine.checkpointer.states["model"].state_dict()
Contributor:
Where is this coming from? When you call this, does it create the sd right then, or did it have to be saved in an earlier train step? Does it return the sd on GPU or CPU? Also, does it handle blocking the trainer from updating the weights while it's getting them?

Contributor Author:
I'm accessing the module state-dict prepped by torchtitan as part of checkpoint save.

  1. This is an in-memory state-dict (Tensor/DTensor).
  2. It returns tensors with their original storage, i.e. GPU/UVM-backed tensors.

> Also does it handle blocking the trainer from updating the weights while it's getting them?

Hmm, it does not block the trainer. However, ForgeEngine drives the trainer via train_step, so there are no race conditions with the current code.

There are improvements to be made to this code. In the ideal case:

  1. The state-dict gets prepped for both weight-exchange and checkpoint-save purposes.
  2. After the initial prep, we can cache the prepped state-dict across later training steps for efficiency (where there is an opportunity).
  3. We move all the model weights and optimizer state to torchstore.
  4. The policy engine (only) looks up the model weights from torchstore.
  5. Async checkpointing looks up the model weights and optimizer states to upload them to remote persistent storage.

We don't have all the pieces right now, but tapping into the checkpoint state-dict is the right first step.

Contributor:
I guess you're right that it should be mostly safe since we control the update from the controller. But since they're async calls they could be overlapped, so we'll have to be careful for now.
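To make the "overlapped async calls" concern concrete, here is a minimal sketch of serializing weight reads against weight updates with an `asyncio.Lock`; the class and method names are hypothetical and not part of this PR:

```python
import asyncio
from typing import Awaitable, Callable


class WeightSyncGuard:
    """Serializes push_weights-style reads against train_step-style updates."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()

    async def run_train_step(self, step_fn: Callable[[], None]) -> None:
        # The optimizer step mutates the shared state dict, so hold the lock.
        async with self._lock:
            step_fn()

    async def push_weights(self, push_fn: Callable[[], Awaitable[None]]) -> None:
        # Reading the same state dict must not overlap an in-flight step.
        async with self._lock:
            await push_fn()
```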

flattened_state_dict, _ = flatten_state_dict(sd)
if self.engine.checkpointer.sd_adapter is None:
raise RuntimeError(
"Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
)
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
# TODO: Figure out how to gracefully handle which model-to-vLLM conversion is needed
vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
Contributor:
Curious: what is the reasoning for doing this at the learner and not at the generator? The trainer just pushes its weights, and the generator can modify the sd based on its implementation (vLLM, sglang, etc.).

Contributor Author:
This is a temporary thing I did. It will be moved into the generator's sd loading, and it will be moved out of the trainer/generator critical path based on efficiency numbers.

Contributor:
To add to Pradeep's answer, vLLM already handles its own hf -> vllm mapping. The only reason we've recreated it is so we can add a sharded loading solution which we want to eventually upstream. It will be on the generator side like he said.

key = f"{self.state_dict_key}{DELIM}{policy_version}"
start_time = time.time()
await ts.put_state_dict(state_dict=vllm_ready_hf_sd, key=key)
Contributor (@Ritesh1905, Sep 15, 2025):
Question: are there benefits to doing this at the state-dict level rather than at the key level, where we could parallelize the individual put operations per key?

Member:
Right now, the benefit is simplicity.

Eventually I imagine we will want to do this on a per-slice level.

cc @LucasLLC
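For reference, the per-key variant raised in the question above could look roughly like this; it assumes torchstore exposes an async per-tensor `ts.put(key, tensor)`, which is an assumption and not confirmed by this PR:

```python
import asyncio

import torch
import torchstore as ts


async def put_state_dict_per_key(
    state_dict: dict[str, torch.Tensor], prefix: str, delim: str = "/"
) -> None:
    # One put per tensor, issued concurrently, instead of a single
    # monolithic put_state_dict call over the whole state dict.
    await asyncio.gather(
        *(
            ts.put(f"{prefix}{delim}{name}", tensor)  # hypothetical per-key API
            for name, tensor in state_dict.items()
        )
    )
```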

end_time = time.time()
self.logger.debug(
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)

@endpoint
async def cleanup(self) -> None:
if self.engine.checkpointer:
self.engine.checkpointer.close()


def llama3_hf_to_vllm(hf_trainer_sd: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert an HF-formatted state-dict to vLLM format. Ideally this conversion
    would not be needed if vLLM fully supported loading the HF-formatted
    llama3 model.
    """
    for i in range(32):  # number of layers in the llama3 8B model
        prefix = f"model.layers.{i}."
        # QKV fusion
        q = hf_trainer_sd.pop(prefix + "self_attn.q_proj.weight")
        k = hf_trainer_sd.pop(prefix + "self_attn.k_proj.weight")
        v = hf_trainer_sd.pop(prefix + "self_attn.v_proj.weight")
        hf_trainer_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat(
            [q, k, v], dim=0
        )
        # MLP gate_up_proj fusion
        gate = hf_trainer_sd.pop(prefix + "mlp.gate_proj.weight")
        up = hf_trainer_sd.pop(prefix + "mlp.up_proj.weight")
        hf_trainer_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)

    return hf_trainer_sd
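A toy sanity check of the fusion above, using synthetic shapes rather than the real llama3 dimensions, just to show how the q/k/v and gate/up projections are concatenated along dim 0:

```python
import torch

# 32 layers of tiny [out=8, in=16] projections stand in for the real weights.
toy_sd = {
    f"model.layers.{i}.{name}": torch.randn(8, 16)
    for i in range(32)
    for name in (
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
    )
}
fused = llama3_hf_to_vllm(toy_sd)
assert fused["model.layers.0.self_attn.qkv_proj.weight"].shape == (24, 16)
assert fused["model.layers.0.mlp.gate_up_proj.weight"].shape == (16, 16)
```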


def _qwen3_hf_to_vllm(
    sd: dict[str, torch.Tensor], num_layers: int
) -> dict[str, torch.Tensor]:
177 changes: 81 additions & 96 deletions tests/integration_tests/test_policy_update.py
@@ -4,18 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, Tuple

import pytest
import pytest_asyncio

import torch

import torchstore as ts
from forge.actors.policy import EngineConfig, Policy, SamplingConfig

from forge.actors.trainer import RLTrainer
from forge.controller.service import ServiceConfig, spawn_service
from forge.data.sharding import VLLMSharding
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict

from omegaconf import DictConfig, OmegaConf
from transformers import AutoModelForCausalLM

requires_cuda = pytest.mark.skipif(
@@ -24,6 +27,10 @@
)


logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def convert_state_dict(saved_sd):
"""
Convert transformers state dict to vLLM format.
@@ -70,16 +77,6 @@ def convert_state_dict(saved_sd):
return load_sd


async def save_state_dict(
store: MultiProcessStore, state_dict: dict[str, torch.Tensor], key_prefix: str
):
print(f"Saving {len(state_dict)} tensors")

await push_state_dict(store, state_dict, key_prefix)

print(f"Successfully saved {len(state_dict)} tensors")


def calculate_expected_shard(
full_tensor: torch.Tensor,
param_name: str,
@@ -127,8 +124,6 @@ def validate_loaded_tensors_equals_original(
For tensor parallel cases, instead of gathering sharded tensors, we shard
the original tensor and compare it with the loaded shard.
"""
validation_errors = []

for param_name, loaded_tensor in loaded_state_dict.items():
if param_name in original_state_dict:
expected_tensor = original_state_dict[param_name]
Expand All @@ -145,22 +140,27 @@ def validate_loaded_tensors_equals_original(
else:
tensor_to_compare = expected_tensor.cpu().float()

# Trainer-emitted and loaded tensors are of type bfloat16,
# whereas an HF-model-loaded (expected) tensor has type float16.
if not torch.allclose(
loaded_tensor.float(),
tensor_to_compare,
rtol=1e-5,
atol=1e-8,
rtol=1e-2,
Contributor:
Is this new rtol/atol expected? Question for @pbontrager.

Contributor:
This is never great, but given the bf16/fp16 comments I could see that. This is also an allclose and not a comparison of the mean, so we should be safe here. If we can load the HF side with bf16 instead of fp16, we might be able to regain the tighter tolerance.

atol=1e-3,
):
validation_errors.append(
logger.warning(
f"Loaded tensor {param_name} does not equal original.\n"
f"dtype = {loaded_tensor.dtype} vs {expected_tensor.dtype}\n"
f"shape= {loaded_tensor.shape} vs {expected_tensor.shape}\n,"
f"values = {loaded_tensor.copy()} vs {expected_tensor.copy()}"
)
raise ValueError(
f"Loaded tensor {param_name} does not equal original "
f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})"
)
else:
print(f"Loaded tensor {param_name} correctly validated")

if validation_errors:
raise ValueError(f"Validation failed: {validation_errors}")

print(
f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original"
)
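Following up on the tolerance thread above, a minimal sketch of loading the HF reference model in bf16 so the dtypes match the trainer's output; whether that actually allows restoring the tighter rtol/atol is untested here:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the reference weights in bf16 to match the trainer-emitted dtype,
# instead of comparing fp16 references against bf16 outputs.
reference = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
expected_state_dict = reference.state_dict()
```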
@@ -191,47 +191,13 @@ def get_configs(worker_size: int, model_name: str) -> Tuple[Dict, ServiceConfig]:
return policy_config, service_config


async def run_policy_integration(
store, original_state_dict, worker_size
) -> Dict[str, torch.Tensor]:
"""
Common helper function to test Policy integration with different GPU configurations.

Args:
store: TorchStore instance
original_state_dict: Original state dict for validation
num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
"""
print(f"=== PHASE 2: Testing Policy Integration (Workers: {worker_size}) ===")

policy_config, service_config = get_configs(
worker_size=1, model_name="meta-llama/Llama-3.1-8B-Instruct"
)
policy = await spawn_service(service_config, Policy, store=store, **policy_config)

# Policy engine start with default version 0 that gets incremented.
print("Calling Policy.update() to load weights from torchstore...")
await policy.update_weights.call()
print(
"Successfully called Policy.update_weights() to load weights from torchstore!"
)
# We get the result as a list.
results = await policy._get_model_params.call()
assert len(results) == 1
print("Successfully got model state dict after update")
return results[0]


@pytest_asyncio.fixture(scope="session")
async def llama3_torchstore_setup():
async def setup_test():
"""
Pytest fixture to load Llama 3.1 8B-Instruct and write state dict to torchstore.
Uses session scope so it's only called once when both tests are run.
Pytest fixture to load Llama 3.1 8B-Instruct. We use the loaded state dict
as the source of truth for validation. Uses session scope so it's only
called once across the tests.
"""
print("=== PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===")

store = await MultiProcessStore.create_store()

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Load the model from local path - using device_map="auto" for efficient loading
@@ -244,55 +210,74 @@ async def llama3_torchstore_setup():

original_state_dict = model.state_dict()
print(f"Original state dict has {len(original_state_dict)} parameters")
hf_state_dict = convert_state_dict(original_state_dict)
print(f"Converted state dict has {len(hf_state_dict)} parameters")

print("Converting transformers state dict to vLLM format...")
converted_state_dict = convert_state_dict(original_state_dict)
print(f"Converted state dict has {len(converted_state_dict)} parameters")
return hf_state_dict

state_dict_key = "model_state_dict/1" # {app_namespace}/{version}
await save_state_dict(store, converted_state_dict, state_dict_key)
print(
f"Successfully wrote converted state dict to torchstore with key: {state_dict_key}"

async def run_rl_trainer(worker_size) -> None:
"""
Spawn the RL trainer and push its weights to torchstore.
Args:
worker_size: Number of workers/procs.
"""
cfg: DictConfig = OmegaConf.load("apps/rl/llama3_8b.yaml")
rl_trainer = await spawn_service(
ServiceConfig(procs_per_replica=worker_size, with_gpus=True, num_replicas=1),
RLTrainer,
**cfg.trainer,
Contributor:
Maybe we should hardcode the trainer config here rather than load it from apps/rl/llama3_8b.yaml.

)
# Push the weights to torchstore
await rl_trainer.push_weights.choose(policy_version=0)

return store, converted_state_dict
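Sketching the reviewer's suggestion above to inline the trainer config rather than reading apps/rl/llama3_8b.yaml; the keys below are placeholders and would need to mirror whatever the YAML actually defines for RLTrainer:

```python
from omegaconf import OmegaConf

# Hypothetical inline equivalent of cfg.trainer; the real keys must match RLTrainer's fields.
trainer_cfg = OmegaConf.create(
    {
        "model": {"name": "llama3", "flavor": "8B"},
        "training": {"steps": 1, "local_batch_size": 1},
        "checkpoint": {"enable": True},
    }
)
# This would then be passed as **trainer_cfg in place of **cfg.trainer above.
```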

async def run_policy_integration(worker_size) -> Dict[str, torch.Tensor]:
"""
Launch the policy service.

@pytest.mark.asyncio
@requires_cuda
async def test_llama3_policy_update_single(llama3_torchstore_setup):
print("Starting Llama 3 8B torchstore test (single GPU)...")
Args:
worker_size: Number of workers/procs (2+ for tensor parallel)
"""
print(f"=== PHASE 2: Launching Policy Engine (Workers: {worker_size}) ===")

store, original_state_dict = llama3_torchstore_setup
policy_config, service_config = get_configs(
worker_size=worker_size, model_name="meta-llama/Llama-3.1-8B-Instruct"
)
policy = await spawn_service(service_config, Policy, **policy_config)

loaded_state_dict = await run_policy_integration(
store, original_state_dict, worker_size=1
# The policy engine starts with a default version 0 that gets incremented.
print("Calling Policy.update() to load weights from torchstore...")
await policy.update_weights.call()
print(
"Successfully called Policy.update_weights() to load weights from torchstore!"
)
results = await policy._get_model_params.call()
assert len(results) == 1
print("Successfully got model state dict after update")
return results[0]


@pytest.mark.asyncio
@requires_cuda
async def test_llama3_policy_update_single(setup_test):
"""
1. Loads weights from the HF model into an in-memory state-dict (source of truth).
2. Initializes RLTrainer and makes the weights available in torchstore.
3. Initializes Policy and calls update_weights() to load weights from torchstore.
4. Validates the policy weights against the source of truth.
"""
logger.info("Starting Llama 3 8B torchstore test (single GPU)...")
await ts.initialize()
expected_state_dict = setup_test
await run_rl_trainer(worker_size=1)
loaded_state_dict = await run_policy_integration(worker_size=1)

# validating for single resource case.
validate_loaded_tensors_equals_original(
loaded_state_dict, original_state_dict, tensor_parallel_size=0, rank=0
loaded_state_dict, expected_state_dict, tensor_parallel_size=0, rank=0
)

print(
logger.info(
"Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
)


# @pytest.mark.asyncio
# @requires_cuda
# async def test_llama3_policy_update_tp(llama3_torchstore_setup):
# print("Starting tensor parallel test (load full state dict into sharded model)...")
#
# if torch.cuda.device_count() < 2:
# pytest.skip(
# f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
# )
#
# store, original_state_dict = llama3_torchstore_setup
#
# await run_policy_integration(store, original_state_dict, num_gpus=2)
#
# print(
# "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
# )