Load model from torchstore into vLLM #55
Changes from 21 commits

@@ -12,6 +12,9 @@
import torch
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, get_state_dict, MAPPING
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size

@@ -169,6 +172,8 @@ class Policy(Actor):
    enforce_eager: bool = False
    vllm_args: EngineArgs = None
    resources: int = 1
    torchstore: MultiProcessStore = None
    state_dict_key: str = "model_state_dict"

    def __post_init__(self):
        """Build vLLM Arguments

@@ -190,6 +195,7 @@ def __post_init__(self):
            tensor_parallel_size=self.tensor_parallel_size,
            pipeline_parallel_size=self.pipeline_parallel_size,
            enforce_eager=self.enforce_eager,
            gpu_memory_utilization=0.4,
        )
        # Original method returns False when not run in the main thread
        self.vllm_args._is_v1_supported_oracle = lambda *_: True
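
To make the new fields concrete, a Policy wired to a store might be configured roughly as in the sketch below; the `create_store()` factory call and the field values are illustrative assumptions, not part of this diff.

```python
# Illustrative sketch only (store construction API is assumed); run inside an async context.
from torchstore import MultiProcessStore

store = await MultiProcessStore.create_store()  # assumed factory; use the actual setup

policy = Policy(
    tensor_parallel_size=2,
    torchstore=store,
    state_dict_key="model_state_dict",
)
```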

@@ -222,10 +228,157 @@ async def setup(self):
    async def execute_model(self, schedule: SchedulerOutput):
        return self.worker.execute_model(schedule)

    def _get_tensor_parallel_sharding_strategy(
        self, param_name: str
    ) -> tuple[int, bool]:
        """
        Determine the sharding strategy for a parameter in a tensor parallel setup.

        Returns:
            tuple[int, bool]: (shard_dimension, is_sharded)
                - shard_dimension: Which dimension to shard (0 or 1)
                - is_sharded: Whether this parameter should be sharded at all

        Based on vLLM's tensor parallel implementation for LLaMA models:
            - Embedding layers: shard along the vocab dimension (dim 0)
            - Attention projections: qkv_proj shards along the hidden dimension (dim 0), o_proj along the input dimension (dim 1)
            - MLP projections: gate_proj/up_proj shard along the hidden dimension (dim 0), down_proj along the input dimension (dim 1)
            - Layer norms: not sharded (replicated)
            - Output layer: shard along the vocab dimension (dim 0)
        """
        # Parameters that are not sharded (replicated across all tensor parallel ranks)
        if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
            return 0, False

        # Embedding layers - shard along vocab dimension (dim 0)
        if "embed_tokens" in param_name or "lm_head" in param_name:
            return 0, True

        # Attention projections
        if "qkv_proj" in param_name:
            # Input projections: shard output dimension (dim 0)
            return 0, True
        elif "o_proj" in param_name:
            # Output projection: shard input dimension (dim 1)
            return 1, True

        # MLP projections
        elif any(proj in param_name for proj in ["gate_proj", "up_proj"]):
            # Input projections: shard output dimension (dim 0)
            return 0, True
        elif "down_proj" in param_name:
            # Output projection: shard input dimension (dim 1)
            return 1, True

        # Default: shard along dim 0
        return 0, True
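
The mapping this helper encodes can be read directly off the branches above; for representative vLLM LLaMA-style parameter names it resolves as in the sketch below (the names themselves are illustrative).

```python
# Expected (shard_dim, is_sharded) results, following the branches above:
#   "model.layers.0.input_layernorm.weight"      -> (0, False)  # replicated
#   "model.embed_tokens.weight"                  -> (0, True)   # vocab dim
#   "model.layers.0.self_attn.qkv_proj.weight"   -> (0, True)   # output dim
#   "model.layers.0.self_attn.o_proj.weight"     -> (1, True)   # input dim
#   "model.layers.0.mlp.down_proj.weight"        -> (1, True)   # input dim
```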

    def _calculate_tensor_shard(
        self, full_tensor: torch.Tensor, shard_dim: int
    ) -> torch.Tensor:
        """
        Calculate the shard of a full tensor for the current tensor parallel rank.

        Args:
            full_tensor: The full tensor to shard
            shard_dim: Which dimension to shard along (0 or 1)

        Returns:
            torch.Tensor: The sharded tensor for this rank
        """
        tp_rank = self.rank % self.tensor_parallel_size

        tensor_size = full_tensor.shape[shard_dim]

        if tensor_size % self.tensor_parallel_size != 0:
            raise ValueError(
                f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
                f"across {self.tensor_parallel_size} ranks: not evenly divisible"
            )

        shard_size = tensor_size // self.tensor_parallel_size
        start_idx = tp_rank * shard_size
        end_idx = start_idx + shard_size

        if shard_dim == 0:
            return full_tensor[start_idx:end_idx]
        elif shard_dim == 1:
            return full_tensor[:, start_idx:end_idx]
        else:
            raise ValueError(f"Unsupported shard dimension: {shard_dim}")

    async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
        """
        Load the full state dict from torchstore into a tensor parallel model with
        deterministic sharding.
        """
        updated_count = 0

        for param_name in current_state_dict.keys():
            current_tensor = current_state_dict[param_name]

            # Load the full tensor from torchstore
            stored_tensor = await self.torchstore.get(
                f"{self.state_dict_key}{DELIM}{param_name}"
            )

            # Determine sharding strategy for this parameter
            shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(
                param_name
            )

            if not is_sharded:
                # Parameter is replicated - shapes should match exactly
                if stored_tensor.shape != current_tensor.shape:
                    raise ValueError(
                        f"Replicated parameter {param_name} has mismatched shapes: "
                        f"{stored_tensor.shape} vs {current_tensor.shape}"
                    )

                # Direct copy for replicated parameters
                current_state_dict[param_name].copy_(stored_tensor)

            else:
                # Need to shard the full tensor
                sharded_tensor = self._calculate_tensor_shard(stored_tensor, shard_dim)

                if sharded_tensor.shape != current_tensor.shape:
                    raise ValueError(
                        f"Calculated shard for {param_name} has wrong shape: "
                        f"{sharded_tensor.shape} vs expected {current_tensor.shape}"
                    )

                current_state_dict[param_name].copy_(sharded_tensor)

            updated_count += 1

        logger.info(f"Successfully updated {updated_count} parameters")

    @endpoint
    async def update(self):
        """Update model weights by reading state dict from torchstore"""
        if self.torchstore is None:
            raise Exception("No torchstore configured, cannot update model")

        logger.info(
            f"Starting model update from torchstore with key: {self.state_dict_key}"
        )

        # Get the current model from the worker
        model = self.worker.model_runner.model
        current_state_dict = model.state_dict()

        logger.info(f"Current state dict has {len(current_state_dict)} parameters")
        logger.info(f"Tensor parallel size: {self.tensor_parallel_size}")

        # Tensor parallel model - use deterministic sharding strategy
        logger.info("Loading state dict with tensor parallel sharding...")
        await self._load_tensor_parallel_state_dict(current_state_dict)

        # Load the updated state dict into the model
        model.load_state_dict(current_state_dict, strict=True)

        logger.info("Successfully updated model weights from torchstore")

    @endpoint
    async def setup_kv_cache(self):

@@ -261,27 +414,67 @@ async def setup_kv_cache(self):
    async def get_vllm_args(self):
        return self.vllm_args

    @endpoint
    async def get_model_params(self):
        model = self.worker.model_runner.model
        state_dict = {}

        for name, param in model.named_parameters():
            # only use one layer for testing, otherwise it's too slow
            if "layers.0" in name:
                state_dict[name] = param.cpu().detach()
        return state_dict

Review thread on get_model_params:
- Reviewer: Is this function purely for testing, or do we plan to leave it in for the final implementation?
- Author: In theory just for testing, but I think we need to leave it in, because I don't think there is another way to get the loaded params back from vLLM to the test for comparison with the saved state dict.
- Reviewer: Is this a method that can be patched into the actor class in the test? For example, you could do a TestPolicy(Policy) and then add this method.
- Author: Yeah, I can't seem to get this to work. I've tried what you suggested, and patching it in different ways, but nothing seems to work.
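
For reference, the TestPolicy subclass suggested in the thread above would look something like the sketch below (untested, and per the author it did not work in practice).

```python
# Hypothetical test-only subclass from the review discussion (reported not to work):
class TestPolicy(Policy):
    @endpoint
    async def get_model_params(self):
        model = self.worker.model_runner.model
        # Only pull one layer, to keep the comparison fast.
        return {
            name: param.cpu().detach()
            for name, param in model.named_parameters()
            if "layers.0" in name
        }
```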

    def setup_worker(self):
        """Build and Instantiate vLLM worker"""
        parallel_config = self.vllm_args.parallel_config
        set_multiprocessing_worker_envs(parallel_config)

        # Get distributed init info
        ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
        distributed_init_method = get_distributed_init_method(ip, port)

        # Calculate local rank properly
        device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
        local_rank = self.rank % device_count

        # Validate local rank
        if local_rank >= device_count:
            raise ValueError(
                f"Local rank {local_rank} exceeds available devices {device_count}"
            )

        # Calculate driver worker properly
        is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0

        # Prepare worker kwargs
        all_kwargs = [{}] * parallel_config.world_size
        all_kwargs[self.rank] = {
            "vllm_config": self.vllm_args,
            "local_rank": local_rank,
            "rank": self.rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }

        logger.info(
            f"Setting up worker: rank={self.rank}, local_rank={local_rank}, "
            f"is_driver={is_driver_worker}, device_count={device_count}"
        )

        try:
            worker = WorkerWrapperBase(self.vllm_args, self.rank)
            worker.init_worker(all_kwargs)
            worker.init_device()
            worker.load_model()
            return worker
        except Exception as e:
            logger.error(f"Failed to setup worker: {e}")
            logger.error(
                f"Worker config: rank={self.rank}, local_rank={local_rank}, "
                f"device_count={device_count}, world_size={parallel_config.world_size}"
            )
            raise
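
For intuition on the rank bookkeeping above, with made-up numbers:

```python
# Example: 2 hosts x 4 GPUs, world_size=8, tensor_parallel_size=4.
#   rank 5 -> local_rank = 5 % 4 = 1, is_driver_worker = (5 % 4 == 0) -> False
#   rank 4 -> local_rank = 4 % 4 = 0, is_driver_worker = True
```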

def convert_input(prompt=None, prompt_token_ids=None):