Load model from torchstore into vLLM #55
Changes from 34 commits
@@ -12,11 +12,13 @@

import torch
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torchstore import MultiProcessStore

from torchstore._state_dict_utils import DELIM

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -32,6 +34,8 @@

from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.data.sharding import Llama3vLLMSharding

logger = logging.getLogger(__name__)
@@ -169,6 +173,8 @@ class Policy(Actor):

    enforce_eager: bool = False
    vllm_args: EngineArgs = None
    resources: int = 1
    torchstore: MultiProcessStore = None
    state_dict_key: str = "model_state_dict"

    def __post_init__(self):
        """Build vLLM Arguments
@@ -222,10 +228,53 @@ async def setup(self):

    async def execute_model(self, schedule: SchedulerOutput):
        return self.worker.execute_model(schedule)

    async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
        """
        Load full state dict from torchstore into tensor parallel model with deterministic sharding.
        """
        updated_count = 0
        # setting explicitly to llama3 for now as it's our only use case
        sharding = Llama3vLLMSharding(self.tensor_parallel_size, self.rank)

        for param_name in current_state_dict.keys():
            current_tensor = current_state_dict[param_name]

            # Load the full tensor from torchstore
            # TODO: only get the part of the tensor that is needed
            stored_tensor = await self.torchstore.get(
                f"{self.state_dict_key}{DELIM}{param_name}"
            )
            sharding.load_from_source_to_target(
                param_name,
                stored_tensor,
                current_tensor,
            )

            updated_count += 1

        logger.info(f"Successfully updated {updated_count} parameters")

    @endpoint
    async def update(self):
-       # TODO: add TorchStore support
-       pass
        """Update model weights by reading state dict from torchstore"""
        if self.torchstore is None:
            raise Exception("No torchstore configured, skipping model update")

        logger.info(
            f"Starting model update from torchstore with key: {self.state_dict_key}"
        )

        model = self.worker.model_runner.model
        current_state_dict = model.state_dict()

        logger.info(f"Current state dict has {len(current_state_dict)} parameters")
        logger.info(f"Tensor parallel size: {self.tensor_parallel_size}")

        await self._load_tensor_parallel_state_dict(current_state_dict)

        logger.info("Successfully updated model weights from torchstore")

    @endpoint
    async def setup_kv_cache(self):
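For reference, this is the torchstore key convention that update() and _load_tensor_parallel_state_dict rely on, shown as a standalone illustration; the parameter name is just an example of a vLLM Llama parameter, and the trainer-side write path is not part of this diff:

# Illustration only, not part of this PR: how the keys read by
# _load_tensor_parallel_state_dict are composed.
from torchstore._state_dict_utils import DELIM

state_dict_key = "model_state_dict"  # Policy.state_dict_key default
param_name = "model.layers.0.self_attn.qkv_proj.weight"  # example vLLM param name
key = f"{state_dict_key}{DELIM}{param_name}"
# Each rank fetches the full (unsharded) tensor with `await torchstore.get(key)`
# and then copies its own slice via Llama3vLLMSharding.load_from_source_to_target.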
@@ -261,6 +310,17 @@ async def setup_kv_cache(self):

    async def get_vllm_args(self):
        return self.vllm_args

    @endpoint
    async def get_model_params(self):
        model = self.worker.model_runner.model
        state_dict = {}

        for name, param in model.named_parameters():
            if "layers.0" not in name:
                continue
            state_dict[name] = param.cpu().detach()
        return state_dict

    def setup_worker(self):
        """Build and Instantiate vLLM worker"""
        parallel_config = self.vllm_args.parallel_config

Review thread on get_model_params:

Reviewer: Is this function purely for testing, or do we plan to leave it in for the final implementation?

Author: In theory just for testing, but I think we need to leave it in, because I don't think there is another way to get the loaded params back from vLLM to the test for comparison with the saved state dict.

Reviewer: Is this a method that can be patched into the actor class in the test? For example, you can do a TestPolicy(Policy) and then add this method.

Author: Yeah, I can't seem to get this to work. I've tried what you suggested, and patching it in different ways, but nothing seems to work.
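A minimal sketch of the TestPolicy(Policy) approach suggested in the thread above, for illustration only; the author reports that patching the endpoint in this way did not work with the actor machinery, which is why get_model_params stays in the diff:

# Hypothetical test-only subclass (not part of this PR); assumes the test can
# spawn TestPolicy the same way it spawns Policy.
from monarch.actor import endpoint

class TestPolicy(Policy):
    @endpoint
    async def get_model_params(self):
        # Pull the layer-0 weights back to CPU so the test can compare them
        # against the state dict that was written to torchstore.
        model = self.worker.model_runner.model
        return {
            name: param.cpu().detach()
            for name, param in model.named_parameters()
            if "layers.0" in name
        }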
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import torch


class BaseSharding(ABC):
    """
    Abstract base class for tensor parallel sharding strategies.
    """

    def __init__(self, tensor_parallel_size: int, rank: int):
        self.tensor_parallel_size = tensor_parallel_size
        self.rank = rank

    @abstractmethod
    def load_from_source_to_target(
        self,
        param_name: str,
        source_tensor: torch.Tensor,
        target_tensor: torch.Tensor,
    ) -> None:
        """
        Copy a source tensor to a target tensor, handling sharding and replication.

        Args:
            param_name: Name of the parameter being loaded
            source_tensor: Source tensor to load from
            target_tensor: Target tensor to load into
        """
        pass


class Llama3vLLMSharding(BaseSharding):
    """
    Llama3 vLLM specific tensor parallel sharding strategy.
    """

    def load_from_source_to_target(
        self,
        param_name: str,
        source_tensor: torch.Tensor,
        target_tensor: torch.Tensor,
    ) -> None:
        """
        Copy a source tensor to a target tensor, handling sharding and replication.
        """
        # Determine sharding strategy for this parameter
        shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name)

        if not is_sharded:
            # Parameter is replicated - shapes should match exactly
            if source_tensor.shape != target_tensor.shape:
                raise ValueError(
                    f"Replicated parameter {param_name} has mismatched shapes: "
                    f"{source_tensor.shape} vs {target_tensor.shape}, skipping"
                )

            # Direct copy for replicated parameters
            target_tensor.copy_(source_tensor)
        else:
            # Need to shard the full tensor
            sharded_tensor = self._calculate_tensor_shard(
                source_tensor, shard_dim, self.tensor_parallel_size, self.rank
            )

            if sharded_tensor.shape != target_tensor.shape:
                raise ValueError(
                    f"Calculated shard for {param_name} has wrong shape: "
                    f"{sharded_tensor.shape} vs expected {target_tensor.shape}, skipping"
                )

            target_tensor.copy_(sharded_tensor)

    def _get_tensor_parallel_sharding_strategy(
        self, param_name: str
    ) -> tuple[int, bool]:
        """
        Determine the sharding strategy for a parameter in tensor parallel setup.

        Returns:
            tuple[int, bool]: (shard_dimension, is_sharded)
                - shard_dimension: Which dimension to shard (0 or 1)
                - is_sharded: Whether this parameter should be sharded at all

        Based on vLLM's tensor parallel implementation for LLaMA models:
        - Embedding layers: shard along vocab dimension (dim 0)
        - Attention projections: qkv_proj shards along hidden dimension (dim 0), o_proj along input dimension (dim 1)
        - MLP projections: gate/up_proj shard along hidden dimension (dim 0), down_proj along input dimension (dim 1)
        - Layer norms: not sharded (replicated)
        - Output layer: shard along vocab dimension (dim 0)
        """
        # Parameters that are not sharded (replicated across all tensor parallel ranks)
        if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
            return 0, False

        # Embedding layers - shard along vocab dimension (dim 0)
        if "embed_tokens" in param_name or "lm_head" in param_name:
            return 0, True

        # Attention projections
        if "qkv_proj" in param_name:
            # Input projections: shard output dimension (dim 0)
            return 0, True
        elif "o_proj" in param_name:
            # Output projection: shard input dimension (dim 1)
            return 1, True

        # MLP projections
        elif any(
            proj in param_name for proj in ["gate_proj", "up_proj", "gate_up_proj"]
        ):
            # Input projections: shard output dimension (dim 0)
            return 0, True
        elif "down_proj" in param_name:
            # Output projection: shard input dimension (dim 1)
            return 1, True

        # Default: try to infer from tensor shape patterns
        return 0, True

    def _calculate_tensor_shard(
        self,
        full_tensor: torch.Tensor,
        shard_dim: int,
        tensor_parallel_size: int,
        rank: int,
    ) -> torch.Tensor:
        """
        Calculate the shard of a full tensor for the current tensor parallel rank.

        Args:
            full_tensor: The full tensor to shard
            shard_dim: Which dimension to shard along (0 or 1)
            tensor_parallel_size: Number of tensor parallel ranks
            rank: Current rank (will be modulo by tensor_parallel_size)

        Returns:
            torch.Tensor: The sharded tensor for this rank
        """
        tp_rank = rank % tensor_parallel_size
        tensor_size = full_tensor.shape[shard_dim]

        if tensor_size % tensor_parallel_size != 0:
            raise ValueError(
                f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
                f"across {tensor_parallel_size} ranks: not evenly divisible"
            )

        shard_size = tensor_size // tensor_parallel_size
        start_idx = tp_rank * shard_size
        end_idx = start_idx + shard_size

        # Create index tensor for the shard range
        indices = torch.arange(start_idx, end_idx, device=full_tensor.device)

        if shard_dim == 0:
            return torch.index_select(full_tensor, 0, indices)
        elif shard_dim == 1:
            return torch.index_select(full_tensor, 1, indices)
        else:
            raise ValueError(f"Unsupported shard dimension: {shard_dim}")