Merged

Changes from 34 commits (37 commits total)
46153ea: initial testing (Aug 13, 2025)
a0fc785: more testing (Aug 14, 2025)
9b6fa9f: init works (Aug 14, 2025)
f26d829: use instruct in path (Aug 14, 2025)
8f25f61: somewhat working (Aug 14, 2025)
006f27e: kinda working, memory/timeout issue (Aug 14, 2025)
10cce6b: store and load working! (Aug 14, 2025)
1e4205c: clean up logging (Aug 15, 2025)
d8de194: sharded working (Aug 15, 2025)
aa916eb: it's working? but _get_tensor_parallel_sharding_strategy is hacky: (Aug 15, 2025)
32f1683: it's working (Aug 15, 2025)
a39444d: some cleanups (Aug 18, 2025)
55c6a49: more clean up (Aug 18, 2025)
52bbf3b: clean ups (Aug 18, 2025)
082b138: get rid of if else logic (Aug 18, 2025)
44caf68: mapping (Aug 18, 2025)
e69dbcd: mostly working (Aug 19, 2025)
08ba23e: mostly working 2 (Aug 19, 2025)
c5dd764: mostly working 3 (Aug 19, 2025)
4743217: single test passes (Aug 19, 2025)
dd36d73: single and fsdp works with calculated sharding (Aug 20, 2025)
ac6a212: convert from script to test (Aug 20, 2025)
b944a2e: cleaning things up (Aug 20, 2025)
8bb9710: more cleaning up (Aug 20, 2025)
8d029f5: move sharding to helper (Aug 20, 2025)
a3355f5: move sharding to helper 2 (Aug 20, 2025)
6fed9b6: refactor (Aug 20, 2025)
6e36dd3: use sharding class in policy and test (Aug 20, 2025)
a78be1b: renames (Aug 20, 2025)
300fe86: use test fixture (Aug 20, 2025)
6003b12: use helper in test (Aug 20, 2025)
ec07ba9: remove extra comments (Aug 20, 2025)
e0a1797: remove extra load (Aug 20, 2025)
5af98a1: clean up prints (Aug 20, 2025)
00c4a03: requested changes (Aug 21, 2025)
d0fb772: requested changes 2 (Aug 21, 2025)
bdd2507: use remote dir (Aug 21, 2025)
66 changes: 63 additions & 3 deletions src/forge/actors/policy.py
@@ -12,11 +12,13 @@

import torch
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torchstore import MultiProcessStore

from torchstore._state_dict_utils import DELIM

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -32,6 +34,8 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.data.sharding import Llama3vLLMSharding

logger = logging.getLogger(__name__)


@@ -169,6 +173,8 @@ class Policy(Actor):
    enforce_eager: bool = False
    vllm_args: EngineArgs = None
    resources: int = 1
    torchstore: MultiProcessStore = None
    state_dict_key: str = "model_state_dict"

    def __post_init__(self):
        """Build vLLM Arguments
Expand Down Expand Up @@ -222,10 +228,53 @@ async def setup(self):
async def execute_model(self, schedule: SchedulerOutput):
return self.worker.execute_model(schedule)

async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
"""
Load full state dict from torchstore into tensor parallel model with deterministic sharding.
"""

updated_count = 0
# setting explictly to llama3 for now as its our only use case
sharding = Llama3vLLMSharding(self.tensor_parallel_size, self.rank)

for param_name in current_state_dict.keys():
current_tensor = current_state_dict[param_name]

# Load the full tensor from torchstore
# TODO: only get the part of the tensor that is needed
stored_tensor = await self.torchstore.get(
f"{self.state_dict_key}{DELIM}{param_name}"
)
sharding.load_from_source_to_target(
param_name,
stored_tensor,
current_tensor,
)

updated_count += 1

logger.info(f"Successfully updated {updated_count} parameters")

@endpoint
async def update(self):
# TODO: add TorchStore support
pass
"""Update model weights by reading state dict from torchstore"""

if self.torchstore is None:
raise Exception("No torchstore configured, skipping model update")

logger.info(
f"Starting model update from torchstore with key: {self.state_dict_key}"
)

model = self.worker.model_runner.model
current_state_dict = model.state_dict()

logger.info(f"Current state dict has {len(current_state_dict)} parameters")
logger.info(f"Tensor parallel size: {self.tensor_parallel_size}")

await self._load_tensor_parallel_state_dict(current_state_dict)

logger.info("Successfully updated model weights from torchstore")

@endpoint
async def setup_kv_cache(self):
Expand Down Expand Up @@ -261,6 +310,17 @@ async def setup_kv_cache(self):
async def get_vllm_args(self):
return self.vllm_args

@endpoint
async def get_model_params(self):
Contributor: Is this function purely for testing, or do we plan to leave it in for the final implementation?

Contributor (author): In theory it's just for testing, but I think we need to leave it in, because I don't think there is another way to get the loaded params back from vllm to the test for comparison with the saved state dict.

Contributor: Is this a method that can be patched into the actor class in the test? For example, you can do a TestPolicy(Policy) and then add this method.

Contributor (author): I can't seem to get this to work. I've tried what you suggested, and patching it in different ways, but nothing seems to work.
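For reference, a minimal sketch of the subclassing approach suggested above. TestPolicy is a hypothetical test-only name, and per the author's reply this pattern did not end up working in their setup:

# Hypothetical test-only subclass sketching the reviewer's suggestion.
from monarch.actor import endpoint

from forge.actors.policy import Policy


class TestPolicy(Policy):
    @endpoint
    async def get_model_params(self):
        # Same idea as the endpoint in this diff: return a small slice of the
        # loaded weights so the test can compare them with the saved state dict.
        model = self.worker.model_runner.model
        return {
            name: param.cpu().detach()
            for name, param in model.named_parameters()
            if "layers.0" in name
        }

The replies above suggest the spawned actor did not pick this up, which is why the endpoint stays on Policy itself in this PR.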

        model = self.worker.model_runner.model
        state_dict = {}

        for name, param in model.named_parameters():
            if "layers.0" not in name:
                continue
            state_dict[name] = param.cpu().detach()
        return state_dict

    def setup_worker(self):
        """Build and Instantiate vLLM worker"""
        parallel_config = self.vllm_args.parallel_config
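For orientation, a rough sketch of the write side that update() expects, inferred from the get() pattern in _load_tensor_parallel_state_dict. The async put() call and the helper name write_state_dict_to_store are assumptions made for illustration (only get() appears in this diff), and the store/policy setup is elided:

# Sketch only: assumes the torchstore store object exposes an async put()
# mirroring the get() used above, and that the Policy was constructed with
# the same store and state_dict_key.
from torchstore._state_dict_utils import DELIM

STATE_DICT_KEY = "model_state_dict"  # must match Policy.state_dict_key


async def write_state_dict_to_store(store, state_dict, key=STATE_DICT_KEY):
    # Store each parameter under "<key><DELIM><param_name>", the layout that
    # Policy.update() reads back one tensor at a time.
    for name, tensor in state_dict.items():
        await store.put(f"{key}{DELIM}{name}", tensor.detach().cpu())

After the write, the policy side reads the same keys back and shards each tensor per rank via Llama3vLLMSharding.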
167 changes: 167 additions & 0 deletions src/forge/data/sharding.py
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import torch


class BaseSharding(ABC):
    """
    Abstract base class for tensor parallel sharding strategies.
    """

    def __init__(self, tensor_parallel_size: int, rank: int):
        self.tensor_parallel_size = tensor_parallel_size
        self.rank = rank

    @abstractmethod
    def load_from_source_to_target(
        self,
        param_name: str,
        source_tensor: torch.Tensor,
        target_tensor: torch.Tensor,
    ) -> None:
        """
        Copy a source tensor to a target tensor, handling sharding and replication.

        Args:
            param_name: Name of the parameter being loaded
            source_tensor: Source tensor to load from
            target_tensor: Target tensor to load into
        """
        pass


class Llama3vLLMSharding(BaseSharding):
    """
    Llama3 vLLM specific tensor parallel sharding strategy.
    """

    def load_from_source_to_target(
        self,
        param_name: str,
        source_tensor: torch.Tensor,
        target_tensor: torch.Tensor,
    ) -> None:
        """
        Copy a source tensor to a target tensor, handling sharding and replication.
        """
        # Determine sharding strategy for this parameter
        shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name)

        if not is_sharded:
            # Parameter is replicated - shapes should match exactly
            if source_tensor.shape != target_tensor.shape:
                raise ValueError(
                    f"Replicated parameter {param_name} has mismatched shapes: "
                    f"{source_tensor.shape} vs {target_tensor.shape}"
                )

            # Direct copy for replicated parameters
            target_tensor.copy_(source_tensor)
        else:
            # Need to shard the full tensor
            sharded_tensor = self._calculate_tensor_shard(
                source_tensor, shard_dim, self.tensor_parallel_size, self.rank
            )

            if sharded_tensor.shape != target_tensor.shape:
                raise ValueError(
                    f"Calculated shard for {param_name} has wrong shape: "
                    f"{sharded_tensor.shape} vs expected {target_tensor.shape}"
                )

            target_tensor.copy_(sharded_tensor)

    def _get_tensor_parallel_sharding_strategy(
        self, param_name: str
    ) -> tuple[int, bool]:
        """
        Determine the sharding strategy for a parameter in a tensor parallel setup.

        Returns:
            tuple[int, bool]: (shard_dimension, is_sharded)
                - shard_dimension: Which dimension to shard (0 or 1)
                - is_sharded: Whether this parameter should be sharded at all

        Based on vLLM's tensor parallel implementation for LLaMA models:
        - Embedding layers: shard along the vocab dimension (dim 0)
        - Attention projections: qkv_proj shards along the hidden dimension (dim 0), o_proj along the input dimension (dim 1)
        - MLP projections: gate/up_proj shard along the hidden dimension (dim 0), down_proj along the input dimension (dim 1)
        - Layer norms: not sharded (replicated)
        - Output layer: shard along the vocab dimension (dim 0)
        """
        # Parameters that are not sharded (replicated across all tensor parallel ranks)
        if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
            return 0, False

        # Embedding layers - shard along vocab dimension (dim 0)
        if "embed_tokens" in param_name or "lm_head" in param_name:
            return 0, True

        # Attention projections
        if "qkv_proj" in param_name:
            # Input projections: shard output dimension (dim 0)
            return 0, True
        elif "o_proj" in param_name:
            # Output projection: shard input dimension (dim 1)
            return 1, True

        # MLP projections
        elif any(
            proj in param_name for proj in ["gate_proj", "up_proj", "gate_up_proj"]
        ):
            # Input projections: shard output dimension (dim 0)
            return 0, True
        elif "down_proj" in param_name:
            # Output projection: shard input dimension (dim 1)
            return 1, True

        # Default: try to infer from tensor shape patterns
        return 0, True

    def _calculate_tensor_shard(
        self,
        full_tensor: torch.Tensor,
        shard_dim: int,
        tensor_parallel_size: int,
        rank: int,
    ) -> torch.Tensor:
        """
        Calculate the shard of a full tensor for the current tensor parallel rank.

        Args:
            full_tensor: The full tensor to shard
            shard_dim: Which dimension to shard along (0 or 1)
            tensor_parallel_size: Number of tensor parallel ranks
            rank: Current rank (will be taken modulo tensor_parallel_size)

        Returns:
            torch.Tensor: The sharded tensor for this rank
        """
        tp_rank = rank % tensor_parallel_size
        tensor_size = full_tensor.shape[shard_dim]

        if tensor_size % tensor_parallel_size != 0:
            raise ValueError(
                f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
                f"across {tensor_parallel_size} ranks: not evenly divisible"
            )

        shard_size = tensor_size // tensor_parallel_size
        start_idx = tp_rank * shard_size
        end_idx = start_idx + shard_size

        # Create index tensor for the shard range
        indices = torch.arange(start_idx, end_idx, device=full_tensor.device)

        if shard_dim == 0:
            return torch.index_select(full_tensor, 0, indices)
        elif shard_dim == 1:
            return torch.index_select(full_tensor, 1, indices)
        else:
            raise ValueError(f"Unsupported shard dimension: {shard_dim}")
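To make the substring rules above concrete, a small usage sketch with toy tensors standing in for real vLLM parameters (tensor_parallel_size=2, rank 1); the parameter names and shapes are illustrative only:

# Illustrative only: toy tensors in place of real vLLM parameters.
import torch

from forge.data.sharding import Llama3vLLMSharding

sharding = Llama3vLLMSharding(tensor_parallel_size=2, rank=1)

# "qkv_proj" shards along dim 0: the full 8x4 source is split in half and
# rank 1 receives rows 4..7.
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
local = torch.empty(4, 4)
sharding.load_from_source_to_target(
    "model.layers.0.self_attn.qkv_proj.weight", full, local
)
assert torch.equal(local, full[4:8])

# "norm" parameters are replicated: shapes must match exactly and the tensor
# is copied as-is.
norm_full = torch.ones(4)
norm_local = torch.empty(4)
sharding.load_from_source_to_target(
    "model.layers.0.input_layernorm.weight", norm_full, norm_local
)
assert torch.equal(norm_local, norm_full)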