diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 1a3076d52..e735ef6e6 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -12,11 +12,13 @@
 import torch
 from monarch.actor import Actor, current_rank, endpoint, proc_mesh
+from torchstore import MultiProcessStore
+
+from torchstore._state_dict_utils import DELIM
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
-from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -32,6 +34,8 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
+from forge.data.sharding import VLLMSharding
+
 logger = logging.getLogger(__name__)
@@ -169,6 +173,7 @@ class Policy(Actor):
     enforce_eager: bool = False
     vllm_args: EngineArgs = None
     resources: int = 1
+    state_dict_key: str = "model_state_dict"
 
     def __post_init__(self):
         """Build vLLM Arguments
@@ -213,7 +218,8 @@ def __post_init__(self):
         assert self.vllm_args.parallel_config.world_size == self.resources
 
     @endpoint
-    async def setup(self):
+    async def setup(self, store: MultiProcessStore | None = None):
+        self.torchstore = store
         # TODO: remove ["gpus"] when monarch implements a flat rank
         self.rank = current_rank()["gpus"]
         self.worker = self.setup_worker()
@@ -222,10 +228,50 @@ async def setup(self):
     async def execute_model(self, schedule: SchedulerOutput):
         return self.worker.execute_model(schedule)
 
+    async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
+        """
+        Load the full state dict from torchstore into a tensor parallel model
+        with deterministic sharding.
+        """
+        updated_count = 0
+        # setting explicitly to llama3 for now as it's our only use case
+        sharding = VLLMSharding(self.tensor_parallel_size, self.rank)
+
+        for param_name in current_state_dict.keys():
+            current_tensor = current_state_dict[param_name]
+
+            # Load the full tensor from torchstore
+            # TODO: only fetch the part of the tensor that this rank needs
+            stored_tensor = await self.torchstore.get(
+                f"{self.state_dict_key}{DELIM}{param_name}"
+            )
+            sharding.load_from_source_to_target(
+                param_name,
+                stored_tensor,
+                current_tensor,
+            )
+
+            updated_count += 1
+
+        logger.debug(f"Updated {updated_count} parameters from torchstore")
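For reference, a minimal sketch (not part of the diff) of the key layout this loader expects in torchstore; it assumes push_state_dict, used by the test below, stores each tensor under the same prefix{DELIM}name scheme that the get() call above reads back:

# Illustrative sketch only. DELIM's concrete value is torchstore-internal and
# treated here as an opaque separator.
from torchstore._state_dict_utils import DELIM

state_dict_key = "model_state_dict"  # Policy.state_dict_key default
param_name = "model.layers.0.self_attn.qkv_proj.weight"
key = f"{state_dict_key}{DELIM}{param_name}"
# update() then fetches each parameter with: await store.get(key)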
+ """ + + updated_count = 0 + # setting explictly to llama3 for now as its our only use case + sharding = VLLMSharding(self.tensor_parallel_size, self.rank) + + for param_name in current_state_dict.keys(): + current_tensor = current_state_dict[param_name] + + # Load the full tensor from torchstore + # TODO: only get the part of the tensor that is needed + stored_tensor = await self.torchstore.get( + f"{self.state_dict_key}{DELIM}{param_name}" + ) + sharding.load_from_source_to_target( + param_name, + stored_tensor, + current_tensor, + ) + + updated_count += 1 + @endpoint async def update(self): - # TODO: add TorchStore support - pass + """Update model weights by reading state dict from torchstore""" + + if self.torchstore is None: + raise Exception("No torchstore configured, skipping model update") + + logger.debug( + f"Starting model update from torchstore with key: {self.state_dict_key}" + ) + + model = self.worker.model_runner.model + current_state_dict = model.state_dict() + + logger.debug(f"Current state dict has {len(current_state_dict)} parameters") + + await self._load_tensor_parallel_state_dict(current_state_dict) + + logger.debug("Successfully updated model weights from torchstore") @endpoint async def setup_kv_cache(self): @@ -261,6 +307,17 @@ async def setup_kv_cache(self): async def get_vllm_args(self): return self.vllm_args + @endpoint + async def get_model_params(self): + model = self.worker.model_runner.model + state_dict = {} + + for name, param in model.named_parameters(): + if "layers.0" not in name: + continue + state_dict[name] = param.cpu().detach() + return state_dict + def setup_worker(self): """Build and Instantiate vLLM worker""" parallel_config = self.vllm_args.parallel_config diff --git a/src/forge/data/sharding.py b/src/forge/data/sharding.py new file mode 100644 index 000000000..2027f8a43 --- /dev/null +++ b/src/forge/data/sharding.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class VLLMSharding: + """ + vLLM specific tensor parallel sharding strategy. + """ + + def __init__(self, tensor_parallel_size: int, rank: int): + self.tensor_parallel_size = tensor_parallel_size + self.rank = rank + + def load_from_source_to_target( + self, + param_name: str, + source_tensor: torch.Tensor, + target_tensor: torch.Tensor, + ) -> None: + """ + Copy a source tensor to a target tensor, handling sharding and replication. 
+ """ + # Determine sharding strategy for this parameter + shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name) + + if not is_sharded: + # Parameter is replicated - shapes should match exactly + if source_tensor.shape != target_tensor.shape: + raise ValueError( + f"Replicated parameter {param_name} has mismatched shapes: " + f"{source_tensor.shape} vs {target_tensor.shape}, skipping" + ) + + # Direct copy for replicated parameters + target_tensor.copy_(source_tensor) + else: + # Need to shard the full tensor + sharded_tensor = self._calculate_tensor_shard( + source_tensor, shard_dim, self.tensor_parallel_size, self.rank + ) + + if sharded_tensor.shape != target_tensor.shape: + raise ValueError( + f"Calculated shard for {param_name} has wrong shape: " + f"{sharded_tensor.shape} vs expected {target_tensor.shape}, skipping" + ) + + target_tensor.copy_(sharded_tensor) + + def _get_tensor_parallel_sharding_strategy( + self, param_name: str + ) -> tuple[int, bool]: + """ + Determine the sharding strategy for a parameter in tensor parallel setup. + + Returns: + tuple[int, bool]: (shard_dimension, is_sharded) + - shard_dimension: Which dimension to shard (0 or 1) + - is_sharded: Whether this parameter should be sharded at all + + Based on vLLM's tensor parallel implementation for LLaMA models: + - Embedding layers: shard along vocab dimension (dim 0) + - Attention projections: qkv_proj shard along hidden dimension (dim 0), o_proj along input dimension (dim 1) + - MLP projections: gate/up_proj shard along hidden dimension (dim 0), down_proj along input dimension (dim 1) + - Layer norms: not sharded (replicated) + - Output layer: shard along vocab dimension (dim 0) + """ + # Parameters that are not sharded (replicated across all tensor parallel ranks) + if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]): + return 0, False + + # Embedding layers - shard along vocab dimension (dim 0) + if "embed_tokens" in param_name or "lm_head" in param_name: + return 0, True + + # Attention projections + if "qkv_proj" in param_name: + # Input projections: shard output dimension (dim 0) + return 0, True + elif "o_proj" in param_name: + # Output projection: shard input dimension (dim 1) + return 1, True + + # MLP projections + elif any( + proj in param_name for proj in ["gate_proj", "up_proj", "gate_up_proj"] + ): + # Input projections: shard output dimension (dim 0) + return 0, True + elif "down_proj" in param_name: + # Output projection: shard input dimension (dim 1) + return 1, True + + # Default: try to infer from tensor shape patterns + return 0, True + + def _calculate_tensor_shard( + self, + full_tensor: torch.Tensor, + shard_dim: int, + tensor_parallel_size: int, + rank: int, + ) -> torch.Tensor: + """ + Calculate the shard of a full tensor for the current tensor parallel rank. 
+    def _calculate_tensor_shard(
+        self,
+        full_tensor: torch.Tensor,
+        shard_dim: int,
+        tensor_parallel_size: int,
+        rank: int,
+    ) -> torch.Tensor:
+        """
+        Calculate the shard of a full tensor for the current tensor parallel rank.
+
+        Args:
+            full_tensor: The full tensor to shard
+            shard_dim: Which dimension to shard along (0 or 1)
+            tensor_parallel_size: Number of tensor parallel ranks
+            rank: Current rank (taken modulo tensor_parallel_size)
+
+        Returns:
+            torch.Tensor: The sharded tensor for this rank
+        """
+        tp_rank = rank % tensor_parallel_size
+        tensor_size = full_tensor.shape[shard_dim]
+
+        if tensor_size % tensor_parallel_size != 0:
+            raise ValueError(
+                f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
+                f"across {tensor_parallel_size} ranks: not evenly divisible"
+            )
+
+        shard_size = tensor_size // tensor_parallel_size
+        start_idx = tp_rank * shard_size
+        end_idx = start_idx + shard_size
+
+        # Create an index tensor for the shard range
+        indices = torch.arange(start_idx, end_idx, device=full_tensor.device)
+
+        if shard_dim == 0:
+            return torch.index_select(full_tensor, 0, indices)
+        elif shard_dim == 1:
+            return torch.index_select(full_tensor, 1, indices)
+        else:
+            raise ValueError(f"Unsupported shard dimension: {shard_dim}")
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
new file mode 100644
index 000000000..733abcd21
--- /dev/null
+++ b/tests/integration_tests/test_policy_update.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import pytest
+import pytest_asyncio
+
+import torch
+
+from forge.actors.policy import Policy
+from forge.data.sharding import VLLMSharding
+from monarch.actor import proc_mesh
+from torchstore import MultiProcessStore
+from torchstore._state_dict_utils import push_state_dict
+from transformers import AutoModelForCausalLM
+
+from vllm.utils import get_open_port
+
+requires_cuda = pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA not available",
+)
+
+
+def convert_state_dict(saved_sd):
+    """
+    Convert a transformers state dict to vLLM format.
+
+    Key conversions:
+    1. Copy over directly mapped keys (down_proj, input_layernorm, etc.)
+    2. Fuse QKV projections: combine q_proj, k_proj, v_proj into qkv_proj
+    3. Fuse MLP projections: combine gate_proj and up_proj into gate_up_proj
+    """
+    load_sd = {}
+    num_layers = 32  # For Llama 3.1 8B
+
+    # Copy over directly mapped keys
+    for k in saved_sd:
+        if any(
+            x in k
+            for x in [
+                "down_proj",
+                "input_layernorm",
+                "post_attention_layernorm",
+                "o_proj",
+                "norm.weight",
+                "embed_tokens.weight",
+                "lm_head.weight",
+            ]
+        ):
+            load_sd[k] = saved_sd[k]
+
+    # Fuse QKV and gate/up projections per layer
+    for i in range(num_layers):
+        prefix = f"model.layers.{i}."
+
+        # QKV fusion
+        q = saved_sd[prefix + "self_attn.q_proj.weight"]
+        k = saved_sd[prefix + "self_attn.k_proj.weight"]
+        v = saved_sd[prefix + "self_attn.v_proj.weight"]
+        load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
+
+        # MLP gate_up_proj fusion
+        gate = saved_sd[prefix + "mlp.gate_proj.weight"]
+        up = saved_sd[prefix + "mlp.up_proj.weight"]
+        load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
+
+    return load_sd
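A sketch (not part of the diff) of the fused shapes this conversion produces, assuming the standard Llama 3.1 8B configuration (hidden=4096, 32 query heads, 8 KV heads, head_dim=128, intermediate=14336):

import torch

q = torch.empty(4096, 4096)   # q_proj: [num_heads * head_dim, hidden]
k = torch.empty(1024, 4096)   # k_proj: [num_kv_heads * head_dim, hidden]
v = torch.empty(1024, 4096)   # v_proj: [num_kv_heads * head_dim, hidden]
assert torch.cat([q, k, v], dim=0).shape == (6144, 4096)  # fused qkv_proj

gate = torch.empty(14336, 4096)  # gate_proj: [intermediate, hidden]
up = torch.empty(14336, 4096)    # up_proj:   [intermediate, hidden]
assert torch.cat([gate, up], dim=0).shape == (28672, 4096)  # fused gate_up_proj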
+
+
+async def save_state_dict(
+    store: MultiProcessStore, state_dict: dict[str, torch.Tensor], key_prefix: str
+):
+    print(f"Saving {len(state_dict)} tensors")
+
+    await push_state_dict(store, state_dict, key_prefix)
+
+    print(f"Successfully saved {len(state_dict)} tensors")
+
+
+def calculate_expected_shard(
+    full_tensor: torch.Tensor,
+    param_name: str,
+    tensor_parallel_size: int,
+    rank: int,
+) -> torch.Tensor:
+    """
+    Calculate the expected shard of a full tensor for comparison with a loaded tensor.
+    This is only used for validation in the tests.
+
+    Args:
+        full_tensor: The full tensor to shard
+        param_name: Name of the parameter (used to determine the sharding strategy)
+        tensor_parallel_size: Number of tensor parallel ranks
+        rank: Current rank
+
+    Returns:
+        torch.Tensor: The expected sharded tensor for this rank
+    """
+    sharding = VLLMSharding(tensor_parallel_size, rank)
+    shard_dim, is_sharded = sharding._get_tensor_parallel_sharding_strategy(param_name)
+
+    if not is_sharded:
+        return full_tensor
+
+    return sharding._calculate_tensor_shard(
+        full_tensor, shard_dim, tensor_parallel_size, rank
+    )
+
+
+def validate_loaded_tensors_equals_original(
+    loaded_state_dict: dict[str, torch.Tensor],
+    original_state_dict: dict[str, torch.Tensor],
+    tensor_parallel_size: int,
+    rank: int,
+):
+    """
+    Verify that every tensor loaded by the policy equals the original tensor.
+
+    For tensor parallel cases, instead of gathering the sharded tensors, we shard
+    the original tensor and compare it with the loaded shard.
+    """
+    validation_errors = []
+
+    for param_name, loaded_tensor in loaded_state_dict.items():
+        if param_name in original_state_dict:
+            expected_tensor = original_state_dict[param_name]
+
+            if tensor_parallel_size > 1:
+                expected_shard = calculate_expected_shard(
+                    expected_tensor,
+                    param_name,
+                    tensor_parallel_size,
+                    rank,
+                )
+                tensor_to_compare = expected_shard.cpu().float()
+            else:
+                tensor_to_compare = expected_tensor.cpu().float()
+
+            if not torch.allclose(
+                loaded_tensor.float(),
+                tensor_to_compare,
+                rtol=1e-5,
+                atol=1e-8,
+            ):
+                validation_errors.append(
+                    f"Loaded tensor {param_name} does not equal original "
+                    f"(shapes: loaded={loaded_tensor.shape}, expected={tensor_to_compare.shape})"
+                )
+            else:
+                print(f"Loaded tensor {param_name} correctly validated")
+
+    if validation_errors:
+        raise ValueError(f"Validation failed: {validation_errors}")
+
+    print(
+        f"Successfully validated that all {len(loaded_state_dict)} loaded tensors equal the originals"
+    )
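An illustrative sketch (not part of the diff) of the shard-and-compare approach used above, reusing the calculate_expected_shard helper with a dim-1 sharded parameter:

import torch

full = torch.arange(32, dtype=torch.float32).reshape(4, 8)
expected = calculate_expected_shard(
    full, "model.layers.0.self_attn.o_proj.weight", tensor_parallel_size=2, rank=0
)
assert torch.equal(expected, full[:, :4])  # o_proj shards along dim 1; rank 0 owns cols 0..3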
+
+
+async def run_policy_integration(store, original_state_dict, num_gpus):
+    """
+    Common helper to test Policy integration with different GPU configurations.
+
+    Args:
+        store: TorchStore instance
+        original_state_dict: Original state dict for validation
+        num_gpus: Number of GPUs to use (1 for single GPU, 2+ for tensor parallel)
+    """
+    print(f"=== PHASE 2: Testing Policy Integration (GPUs: {num_gpus}) ===")
+
+    state_dict_key = "llama3_8b_state_dict"
+
+    # Set up environment variables for vLLM distributed initialization
+    if num_gpus == 1:
+        # Single GPU setup
+        os.environ.setdefault("MASTER_ADDR", "localhost")
+        os.environ.setdefault("MASTER_PORT", "12355")
+        os.environ.setdefault("RANK", "0")
+        os.environ.setdefault("WORLD_SIZE", "1")
+        master_addr = os.environ.get("MASTER_ADDR", "localhost")
+        master_port = os.environ.get("MASTER_PORT", "12355")
+    else:
+        # Multi-GPU setup
+        master_addr = "localhost"
+        master_port = str(get_open_port())
+        os.environ["MASTER_ADDR"] = master_addr
+        os.environ["MASTER_PORT"] = master_port
+        print(f"Using MASTER_PORT: {master_port} for tensor parallel Policy")
+
+    rank = int(os.environ.get("RANK", "0"))
+
+    policy_mesh = await proc_mesh(
+        gpus=num_gpus,
+        env={
+            "MASTER_ADDR": master_addr,
+            "MASTER_PORT": master_port,
+        },
+    )
+
+    # Spawn Policy as a proper Monarch actor
+    policy = await policy_mesh.spawn(
+        "policy",
+        Policy,
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        tensor_parallel_size=num_gpus,
+        pipeline_parallel_size=1,
+        enforce_eager=True,
+        resources=num_gpus,
+        state_dict_key=state_dict_key,
+    )
+
+    await policy.setup.call(store)
+    print("Setup completed successfully!")
+
+    print("Calling Policy.update() to load weights from torchstore...")
+    await policy.update.call()
+    print("Successfully called Policy.update() to load weights from torchstore!")
+
+    model_params = await policy.get_model_params.call()
+    loaded_state_dict = (
+        model_params._values[0] if hasattr(model_params, "_values") else model_params
+    )
+    print("Successfully got model state dict after update")
+
+    validate_loaded_tensors_equals_original(
+        loaded_state_dict, original_state_dict, tensor_parallel_size=num_gpus, rank=rank
+    )
+
+    print("Test passed! State dict successfully loaded into Policy!")
+ """ + print("=== PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===") + + store = await MultiProcessStore.create_store() + + model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct" + + # Load the model from local path - using device_map="auto" for efficient loading + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.float16, # Use half precision to save memory + device_map="auto", + trust_remote_code=True, + ) + + original_state_dict = model.state_dict() + print(f"Original state dict has {len(original_state_dict)} parameters") + + print("Converting transformers state dict to vLLM format...") + converted_state_dict = convert_state_dict(original_state_dict) + print(f"Converted state dict has {len(converted_state_dict)} parameters") + + state_dict_key = "llama3_8b_state_dict" + await save_state_dict(store, converted_state_dict, state_dict_key) + print( + f"Successfully wrote converted state dict to torchstore with key: {state_dict_key}" + ) + + return store, converted_state_dict + + +@pytest.mark.asyncio +@requires_cuda +async def test_llama3_policy_update_single(llama3_torchstore_setup): + print("Starting Llama 3 8B torchstore test (single GPU)...") + + store, original_state_dict = llama3_torchstore_setup + + await run_policy_integration(store, original_state_dict, num_gpus=1) + + print( + "Single GPU test passed! Llama 3.1 8B-Instruct model successfully loaded into Policy via TorchStore!" + ) + + +@pytest.mark.asyncio +@requires_cuda +async def test_llama3_policy_update_tp(llama3_torchstore_setup): + print("Starting tensor parallel test (load full state dict into sharded model)...") + + if torch.cuda.device_count() < 2: + pytest.skip( + f"Only {torch.cuda.device_count()} GPU(s) available, need 2+ for tensor parallel" + ) + + store, original_state_dict = llama3_torchstore_setup + + await run_policy_integration(store, original_state_dict, num_gpus=2) + + print( + "Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!" + )