Publishing weights into torchstore from RLTrainer and getting them from the policy engine. #138
Changes from all commits
@@ -8,12 +8,18 @@
 import logging
 import math
 import os
+import time
 from collections.abc import Mapping
 from dataclasses import dataclass, field, fields
 from typing import Callable

 import torch
+import torchstore as ts

 from monarch.actor import current_rank, current_size, endpoint
 from torch import Tensor
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
+from torchstore.state_dict_utils import DELIM
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -26,7 +32,6 @@
     Parallelism,
     Training,
 )
-
 from torchtitan.distributed import utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
@@ -53,6 +58,7 @@ class RLTrainer(ForgeActor):
     float8: Float8 = field(default_factory=Float8)
     comm: Comm = field(default_factory=Comm)
     loss: Callable = lambda logits, **targets: logits
+    state_dict_key: str = "model_state_dict"

     def __post_init__(self):
         """Initializes config types and env variables.
@@ -71,7 +77,7 @@ def __post_init__(self):
                 f"{f.name} should be a {f.type} type or a dict like object"
             )

-        self.current_step = 0
+        self.current_step = 1  # fragile contract.
         self.num_training_steps = self.training.steps
         self.gradient_accumulation_steps = 1
         self.rank = current_rank().rank
@@ -95,7 +101,8 @@ def __post_init__(self):
     async def setup(self):
         # TODO: update ForgeEngine to not use ForgeJobConfig
        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        engine_config.pop("loss")  # Not part of job config
+        for key in {"loss", "state_dict_key"}:
+            engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
         self.engine.checkpointer.load(step=self.current_step)
         self.engine.optimizers.zero_grad()
@@ -197,8 +204,30 @@ def train_step(
         return {"loss": loss.item()}

     @endpoint
-    def push_weights(self) -> None:
-        pass
+    async def push_weights(self, policy_version: int) -> None:
+        # Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
+        # TODO:
+        # 1. The Checkpointer invokes state-dict flattening during dcp_save for [MODEL].
+        #    We may need to replicate the same in this code path.
+        # 2. Unify the CheckpointManager and TorchStore weights-save control paths.
+        if "model" not in self.engine.checkpointer.states:
+            raise RuntimeError("Model state not found in checkpointer state")
+        sd = self.engine.checkpointer.states["model"].state_dict()
+
+        flattened_state_dict, _ = flatten_state_dict(sd)
+        if self.engine.checkpointer.sd_adapter is None:
+            raise RuntimeError(
+                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
+            )
+        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
+        # TODO: Figure out how to gracefully handle which model-to-vLLM conversion is needed
+        vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
+
+        key = f"{self.state_dict_key}{DELIM}{policy_version}"
+        start_time = time.time()
+        await ts.put_state_dict(state_dict=vllm_ready_hf_sd, key=key)
+        end_time = time.time()
+        self.logger.debug(
+            f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
+        )

     @endpoint
     async def cleanup(self) -> None:
@@ -4,24 +4,31 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict, Tuple
+import logging
+from typing import Callable

 import pytest
 import pytest_asyncio

 import torch

+import torchstore as ts
 from forge.actors.policy import EngineConfig, Policy, SamplingConfig

 from forge.actors.trainer import RLTrainer
 from forge.controller.service import ServiceConfig, spawn_service
 from forge.data.sharding import VLLMSharding
-from torchstore import MultiProcessStore
-from torchstore._state_dict_utils import push_state_dict

 from transformers import AutoModelForCausalLM

 requires_cuda = pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="CUDA not available",
 )
+from forge.actors.trainer import _qwen3_hf_to_vllm
+from huggingface_hub import snapshot_download

+logger: logging.Logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)


 def convert_state_dict(saved_sd):
@@ -70,16 +77,6 @@ def convert_state_dict(saved_sd):
     return load_sd


-async def save_state_dict(
-    store: MultiProcessStore, state_dict: dict[str, torch.Tensor], key_prefix: str
-):
-    print(f"Saving {len(state_dict)} tensors")
-
-    await push_state_dict(store, state_dict, key_prefix)
-
-    print(f"Successfully saved {len(state_dict)} tensors")
-
-
 def calculate_expected_shard(
     full_tensor: torch.Tensor,
     param_name: str,
@@ -127,8 +124,6 @@ def validate_loaded_tensors_equals_original(
     For tensor parallel cases, instead of gathering sharded tensors, we shard
     the original tensor and compare it with the loaded shard.
     """
-    validation_errors = []
-
     for param_name, loaded_tensor in loaded_state_dict.items():
         if param_name in original_state_dict:
             expected_tensor = original_state_dict[param_name]
@@ -145,154 +140,116 @@
             else:
                 tensor_to_compare = expected_tensor.cpu().float()

+            # Trainer-emitted and loaded tensors are bfloat16, whereas an
+            # HF-loaded (expected) tensor is float16.
             if not torch.allclose(
                 loaded_tensor.float(),
                 tensor_to_compare,
-                rtol=1e-5,
-                atol=1e-8,
+                rtol=1e-2,
+                atol=1e-3,
             ):
-                validation_errors.append(
+                logger.warning(
                     f"Loaded tensor {param_name} does not equal original.\n"
                     f"dtype = {loaded_tensor.dtype} vs {expected_tensor.dtype}\n"
                     f"shape = {loaded_tensor.shape} vs {expected_tensor.shape}\n"
                     f"values = {loaded_tensor} vs {expected_tensor}"
                 )
+                raise ValueError(
+                    f"Loaded tensor {param_name} does not equal original "
+                    f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})"
+                )
             else:
                 print(f"Loaded tensor {param_name} correctly validated")

-    if validation_errors:
-        raise ValueError(f"Validation failed: {validation_errors}")
-
     print(
         f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal original"
     )


-def get_configs(worker_size: int, model_name: str) -> Tuple[Dict, ServiceConfig]:
+def get_configs(worker_size: int, model_name: str) -> tuple[dict, ServiceConfig]:
     engine_config = EngineConfig(
         model=model_name,
         tensor_parallel_size=worker_size,
         pipeline_parallel_size=1,
         enforce_eager=True,
     )

     sampling_config = SamplingConfig(
         n=3,
         guided_decoding=True,
     )

     policy_config = {
         "engine_config": engine_config,
         "sampling_config": sampling_config,
     }
     service_config = ServiceConfig(
         procs_per_replica=worker_size, num_replicas=1, with_gpus=True
     )

     return policy_config, service_config


-async def run_policy_integration(
-    store, original_state_dict, worker_size
-) -> Dict[str, torch.Tensor]:
-    """
-    Common helper function to test Policy integration with different GPU configurations.
-
-    Args:
-        store: TorchStore instance
-        original_state_dict: Original state dict for validation
-        num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
-    """
-    print(f"=== PHASE 2: Testing Policy Integration (Workers: {worker_size}) ===")
-
-    policy_config, service_config = get_configs(
-        worker_size=1, model_name="meta-llama/Llama-3.1-8B-Instruct"
-    )
-    policy = await spawn_service(service_config, Policy, store=store, **policy_config)
-
-    # Policy engine starts with default version 0 that gets incremented.
-    print("Calling Policy.update() to load weights from torchstore...")
-    await policy.update_weights.call()
-    print(
-        "Successfully called Policy.update_weights() to load weights from torchstore!"
-    )
-    # We get the result as a list.
-    results = await policy._get_model_params.call()
-    assert len(results) == 1
-    print("Successfully got model state dict after update")
-    return results[0]
-
-
-@pytest_asyncio.fixture(scope="session")
-async def llama3_torchstore_setup():
-    """
-    Pytest fixture to load Llama 3.1 8B-Instruct and write its state dict to torchstore.
-    Uses session scope so it's only called once when both tests are run.
-    """
-    print("=== PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===")
-
-    store = await MultiProcessStore.create_store()
-
-    model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
-    # Load the model - using device_map="auto" for efficient loading
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype=torch.float16,  # Use half precision to save memory
-        device_map="auto",
-        trust_remote_code=True,
-    )
-
-    original_state_dict = model.state_dict()
-    print(f"Original state dict has {len(original_state_dict)} parameters")
-
-    print("Converting transformers state dict to vLLM format...")
-    converted_state_dict = convert_state_dict(original_state_dict)
-    print(f"Converted state dict has {len(converted_state_dict)} parameters")
-
-    state_dict_key = "model_state_dict/1"  # {app_namespace}/{version}
-    await save_state_dict(store, converted_state_dict, state_dict_key)
-    print(
-        f"Successfully wrote converted state dict to torchstore with key: {state_dict_key}"
-    )
-
-    return store, converted_state_dict
-
-
-@pytest.mark.asyncio
-@requires_cuda
-async def test_llama3_policy_update_single(llama3_torchstore_setup):
-    print("Starting Llama 3 8B torchstore test (single GPU)...")
-
-    store, original_state_dict = llama3_torchstore_setup
-
-    loaded_state_dict = await run_policy_integration(
-        store, original_state_dict, worker_size=1
-    )
-
-    # Validating for the single-resource case.
-    validate_loaded_tensors_equals_original(
-        loaded_state_dict, original_state_dict, tensor_parallel_size=0, rank=0
-    )
-
-    print(
-        "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!"
-    )
-
-
-# @pytest.mark.asyncio
-# @requires_cuda
-# async def test_llama3_policy_update_tp(llama3_torchstore_setup):
-#     print("Starting tensor parallel test (load full state dict into sharded model)...")
-#
-#     if torch.cuda.device_count() < 2:
-#         pytest.skip(
-#             f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel"
-#         )
-#
-#     store, original_state_dict = llama3_torchstore_setup
-#
-#     await run_policy_integration(store, original_state_dict, num_gpus=2)
-#
-#     print(
-#         "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!"
-#     )
+class TestWeightSync:
+    """Tests for weight sync between trainer and policy. Currently hardcoded to Qwen3-1.7B."""
+
+    model = "Qwen/Qwen3-1.7B"
+    to_vllm_fn: Callable = _qwen3_hf_to_vllm
+    num_layers = 28
+
+    @pytest_asyncio.fixture
+    def trainer_cfg(self):
+        cached_dir = snapshot_download(repo_id=self.model)
+        return {
+            "model": {
+                "name": "qwen3",
+                "flavor": "1.7B",
+            },
+            "checkpoint": {
+                "enable": True,
+                "folder": "/tmp/saved_checkpoints",
+                "initial_load_path": cached_dir,
+                "initial_load_in_hf": True,
+            },
+        }
+
+    @pytest_asyncio.fixture
+    def expected_sd(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model,
+            dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+        original_state_dict = model.state_dict()
+        # Hack to access through the class without passing in a self param
+        return self.__class__.to_vllm_fn(original_state_dict, self.num_layers)
+
+    @pytest.mark.asyncio
+    @requires_cuda
+    async def test_policy_update_single(self, expected_sd, trainer_cfg):
+        """
+        1. Loads weights from the HF model into an in-memory state dict (source of truth).
+        2. Initializes RLTrainer and makes the weights available in torchstore.
+        3. Initializes Policy and calls update_weights() to load weights from torchstore.
+        4. Validates the policy weights against the source of truth.
+        """
+        worker_size = 1
+        # 1. Initialize TS
+        await ts.initialize()
+        # 2. Trainer push
+        rl_trainer = await spawn_service(
+            ServiceConfig(
+                procs_per_replica=worker_size, with_gpus=True, num_replicas=1
+            ),
+            RLTrainer,
+            **trainer_cfg,
+        )
+        await rl_trainer.push_weights.choose(policy_version=0)
+        # 3. Policy pulls weights
+        policy_config, service_config = get_configs(
+            worker_size=worker_size, model_name=self.model
+        )
+        policy = await spawn_service(service_config, Policy, **policy_config)
+        await policy.update_weights.call()
+        # 4. Validate weights
+        loaded_state_dict = await policy._get_model_params.choose()
+        validate_loaded_tensors_equals_original(
+            loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
+        )
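The test exercises the versioned-key contract between the two sides: `push_weights(policy_version=0)` writes under `{state_dict_key}{DELIM}{policy_version}` and the policy reads the same key when `update_weights()` is called. A toy round trip under that convention (`ts.get_state_dict` is an assumed read-side counterpart to `ts.put_state_dict`; the actual Policy may fetch tensors shard by shard):

```python
import torch
import torchstore as ts
from torchstore.state_dict_utils import DELIM


async def roundtrip_sketch(policy_version: int = 0) -> None:
    await ts.initialize()
    # Toy state dict standing in for the vLLM-ready weights.
    sd = {"model.norm.weight": torch.ones(8)}
    # Same key shape the trainer builds in push_weights().
    key = f"model_state_dict{DELIM}{policy_version}"
    await ts.put_state_dict(state_dict=sd, key=key)
    loaded = await ts.get_state_dict(key)  # assumed API
    assert torch.equal(loaded["model.norm.weight"], sd["model.norm.weight"])
```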
Conversations
Why are we starting at 1? Also, we probably want a TODO to update this from the checkpoint.
Because the policy engine starts at 1. Let's keep this fragile contract as it is. The true version has to come from a config or an external book-keeping entity.
I don't think we can change this without risking breaking checkpoint expectations on the titan side. I'd rather just use a separate variable in the trainer for the "checkpoint name" (it can be a property that's just current_step + 1 for now). This could also be passed in from the controller, which would be better.
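A minimal sketch of that suggestion, keeping titan's 0-based `current_step` contract untouched and deriving the published version from it (class and property names are hypothetical; only the `current_step + 1` rule comes from the comment above):

```python
class TrainerVersioning:
    """Illustrative stand-in for the versioning bits of RLTrainer."""

    def __init__(self, current_step: int = 0) -> None:
        self.current_step = current_step  # stays 0-based for titan checkpointing

    @property
    def checkpoint_version(self) -> int:
        # For now just current_step + 1; a controller could instead pass
        # this value in explicitly, which would be better.
        return self.current_step + 1
```

`push_weights` would then key torchstore entries off `checkpoint_version`, while `self.engine.checkpointer.load(step=self.current_step)` keeps its original semantics.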