Merged

37 commits (changes shown from 21 commits)
46153ea  initial testing (Aug 13, 2025)
a0fc785  more testing (Aug 14, 2025)
9b6fa9f  init works (Aug 14, 2025)
f26d829  use instruct in path (Aug 14, 2025)
8f25f61  somewhat working (Aug 14, 2025)
006f27e  kinda working, memory/timeout issue (Aug 14, 2025)
10cce6b  store and load working! (Aug 14, 2025)
1e4205c  clean up logging (Aug 15, 2025)
d8de194  sharded working (Aug 15, 2025)
aa916eb  it's working? but _get_tensor_parallel_sharding_strategy is hacky: (Aug 15, 2025)
32f1683  it's working (Aug 15, 2025)
a39444d  some cleanups (Aug 18, 2025)
55c6a49  more clean up (Aug 18, 2025)
52bbf3b  clean ups (Aug 18, 2025)
082b138  get rid of if else logic (Aug 18, 2025)
44caf68  mapping (Aug 18, 2025)
e69dbcd  mostly working (Aug 19, 2025)
08ba23e  mostly working 2 (Aug 19, 2025)
c5dd764  mostly working 3 (Aug 19, 2025)
4743217  single test passes (Aug 19, 2025)
dd36d73  single and fsdp works with calculated sharding (Aug 20, 2025)
ac6a212  convert from script to test (Aug 20, 2025)
b944a2e  cleaning things up (Aug 20, 2025)
8bb9710  more cleaning up (Aug 20, 2025)
8d029f5  move sharding to helper (Aug 20, 2025)
a3355f5  move sharding to helper 2 (Aug 20, 2025)
6fed9b6  refactor (Aug 20, 2025)
6e36dd3  use sharding class in policy and test (Aug 20, 2025)
a78be1b  renames (Aug 20, 2025)
300fe86  use test fixture (Aug 20, 2025)
6003b12  use helper in test (Aug 20, 2025)
ec07ba9  remove extra comments (Aug 20, 2025)
e0a1797  remove extra load (Aug 20, 2025)
5af98a1  clean up prints (Aug 20, 2025)
00c4a03  requested changes (Aug 21, 2025)
d0fb772  requested changes 2 (Aug 21, 2025)
bdd2507  use remote dir (Aug 21, 2025)
211 changes: 202 additions & 9 deletions src/forge/actors/policy.py
@@ -12,6 +12,9 @@

import torch
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torchstore import MultiProcessStore

from torchstore._state_dict_utils import DELIM, get_state_dict, MAPPING

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
@@ -169,6 +172,8 @@ class Policy(Actor):
enforce_eager: bool = False
vllm_args: EngineArgs = None
resources: int = 1
torchstore: MultiProcessStore = None
state_dict_key: str = "model_state_dict"

def __post_init__(self):
"""Build vLLM Arguments
@@ -190,6 +195,7 @@ def __post_init__(self):
tensor_parallel_size=self.tensor_parallel_size,
pipeline_parallel_size=self.pipeline_parallel_size,
enforce_eager=self.enforce_eager,
gpu_memory_utilization=0.4,
)
# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
@@ -222,10 +228,157 @@ async def setup(self):
async def execute_model(self, schedule: SchedulerOutput):
return self.worker.execute_model(schedule)

def _get_tensor_parallel_sharding_strategy(
self, param_name: str
) -> tuple[int, bool]:
"""
Determine the sharding strategy for a parameter in tensor parallel setup.

Returns:
tuple[int, bool]: (shard_dimension, is_sharded)
- shard_dimension: Which dimension to shard (0 or 1)
- is_sharded: Whether this parameter should be sharded at all

Based on vLLM's tensor parallel implementation for LLaMA models:
- Embedding layers: shard along vocab dimension (dim 0)
- Attention projections: qkv_proj shards along hidden dimension (dim 0), o_proj along input dimension (dim 1)
- MLP projections: gate/up_proj shard along hidden dimension (dim 0), down_proj along input dimension (dim 1)
- Layer norms: not sharded (replicated)
- Output layer: shard along vocab dimension (dim 0)
"""
# Parameters that are not sharded (replicated across all tensor parallel ranks)
if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
Contributor: Imo we should associate logic like this with the model somehow rather than make it a fixed property of the Policy class. Happy to brainstorm a bit more on the right way to do this (also I assume the TP strategy here is unique to vLLM and does not in general match what's defined in titan?)
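
For illustration only, and not part of this PR: a minimal sketch of how the keyword-to-strategy mapping could be attached to the model (or a per-model registry) instead of being hard-coded on Policy, as the comment above suggests. The keywords and dims mirror the LLaMA-style rules in _get_tensor_parallel_sharding_strategy below; the table name and lookup helper are hypothetical.

# Hypothetical per-model sharding table; order matters because the replicated
# keywords ("norm", "bias", "rotary_emb") must match before the projection keywords.
LLAMA_TP_SHARDING: list[tuple[str, tuple[int, bool]]] = [
    ("norm", (0, False)),
    ("bias", (0, False)),
    ("rotary_emb", (0, False)),
    ("embed_tokens", (0, True)),
    ("lm_head", (0, True)),
    ("qkv_proj", (0, True)),
    ("o_proj", (1, True)),
    ("gate_proj", (0, True)),
    ("up_proj", (0, True)),
    ("down_proj", (1, True)),
]

def lookup_sharding(param_name: str) -> tuple[int, bool]:
    # Return (shard_dim, is_sharded) for the first matching keyword;
    # the fallback matches the PR's default of sharding dim 0.
    for keyword, strategy in LLAMA_TP_SHARDING:
        if keyword in param_name:
            return strategy
    return 0, True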

return 0, False

# Embedding layers - shard along vocab dimension (dim 0)
if "embed_tokens" in param_name or "lm_head" in param_name:
return 0, True

# Attention projections
if "qkv_proj" in param_name:
# Input projections: shard output dimension (dim 0)
return 0, True
elif "o_proj" in param_name:
# Output projection: shard input dimension (dim 1)
return 1, True

# MLP projections
elif any(proj in param_name for proj in ["gate_proj", "up_proj"]):
# Input projections: shard output dimension (dim 0)
return 0, True
elif "down_proj" in param_name:
# Output projection: shard input dimension (dim 1)
return 1, True

# Default: try to infer from tensor shape patterns
return 0, True

def _calculate_tensor_shard(
self, full_tensor: torch.Tensor, shard_dim: int
) -> torch.Tensor:
"""
Calculate the shard of a full tensor for the current tensor parallel rank.

Args:
full_tensor: The full tensor to shard
shard_dim: Which dimension to shard along (0 or 1)

Returns:
torch.Tensor: The sharded tensor for this rank
"""
tp_rank = self.rank % self.tensor_parallel_size
Contributor: As per the impl, we do:
• even sharding
• placement on every rank.
Probably good to document the contract/policy.
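
As a worked example of the contract the comment above asks to document (even sharding, one contiguous slice per TP rank), here is a small illustrative snippet that is not part of the PR; it assumes tensor_parallel_size=2 and sharding along dim 0.

import torch

tensor_parallel_size = 2
full = torch.arange(8.0).reshape(4, 2)               # full tensor, sharded along dim 0
shard_size = full.shape[0] // tensor_parallel_size   # 4 rows / 2 ranks = 2 rows per rank

# rank 0 holds rows [0:2), rank 1 holds rows [2:4); concatenating the shards recovers the full tensor
shards = [full[r * shard_size:(r + 1) * shard_size] for r in range(tensor_parallel_size)]
assert torch.equal(torch.cat(shards, dim=0), full)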

tensor_size = full_tensor.shape[shard_dim]

if tensor_size % self.tensor_parallel_size != 0:
raise ValueError(
f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
f"across {self.tensor_parallel_size} ranks: not evenly divisible"
)

shard_size = tensor_size // self.tensor_parallel_size
start_idx = tp_rank * shard_size
end_idx = start_idx + shard_size

if shard_dim == 0:
return full_tensor[start_idx:end_idx]
elif shard_dim == 1:
return full_tensor[:, start_idx:end_idx]
else:
raise ValueError(f"Unsupported shard dimension: {shard_dim}")

async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
"""
Load full state dict from torchstore into tensor parallel model with deterministic sharding.
"""

updated_count = 0

for param_name in current_state_dict.keys():
current_tensor = current_state_dict[param_name]

# Load the full tensor from torchstore
stored_tensor = await self.torchstore.get(
f"{self.state_dict_key}{DELIM}{param_name}"
)

# Determine sharding strategy for this parameter
shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(
param_name
)

if not is_sharded:
# Parameter is replicated - shapes should match exactly
if stored_tensor.shape != current_tensor.shape:
raise ValueError(
f"Replicated parameter {param_name} has mismatched shapes: "
f"{stored_tensor.shape} vs {current_tensor.shape}, skipping"
)

# Direct copy for replicated parameters
current_state_dict[param_name].copy_(stored_tensor)

else:
# Need to shard the full tensor
sharded_tensor = self._calculate_tensor_shard(stored_tensor, shard_dim)

if sharded_tensor.shape != current_tensor.shape:
raise ValueError(
f"Calculated shard for {param_name} has wrong shape: "
f"{sharded_tensor.shape} vs expected {current_tensor.shape}, skipping"
)

current_state_dict[param_name].copy_(sharded_tensor)

updated_count += 1

logger.info(f"Successfully updated {updated_count} parameters")

@endpoint
async def update(self):
- # TODO: add TorchStore support
- pass
"""Update model weights by reading state dict from torchstore"""

if self.torchstore is None:
raise Exception("No torchstore configured, skipping model update")

logger.info(
f"Starting model update from torchstore with key: {self.state_dict_key}"
)

# Get the current model from the worker
model = self.worker.model_runner.model
current_state_dict = model.state_dict()

logger.info(f"Current state dict has {len(current_state_dict)} parameters")
logger.info(f"Tensor parallel size: {self.tensor_parallel_size}")

# Tensor parallel model - use deterministic sharding strategy
logger.info("Loading state dict with tensor parallel sharding...")
await self._load_tensor_parallel_state_dict(current_state_dict)

# Load the updated state dict into the model
model.load_state_dict(current_state_dict, strict=True)

logger.info("Successfully updated model weights from torchstore")

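To show how update() is meant to be driven end to end, here is a hedged sketch of the producer side: something writes the trainer's full state dict into torchstore under state_dict_key, one tensor per f"{key}{DELIM}{name}" entry, mirroring the get() calls in _load_tensor_parallel_state_dict. This diff only shows the reader side, so the store's put() method and the helper below are assumptions, not the PR's API.

from torchstore._state_dict_utils import DELIM

# Hypothetical producer-side helper (assumes the store exposes a put()
# counterpart to the get() used above, and that key matches Policy.state_dict_key).
async def push_full_state_dict(store, state_dict, key="model_state_dict"):
    for name, tensor in state_dict.items():
        # one entry per parameter, keyed the same way update() reads them back
        await store.put(f"{key}{DELIM}{name}", tensor.detach().cpu())
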
@endpoint
async def setup_kv_cache(self):
@@ -261,27 +414,67 @@ async def setup_kv_cache(self):
async def get_vllm_args(self):
return self.vllm_args

@endpoint
async def get_model_params(self):
Contributor: Is this function purely for testing, or do we plan to leave it in for the final implementation?

Contributor Author: In theory just for testing, but I think we need to leave it in, because I don't think there is another way to get the loaded params back from vLLM to the test for comparison with the saved state dict.

Contributor: Is this a method that can be patched into the actor class in the test? For example, you can do a TestPolicy(Policy) and then add this method.

Contributor Author: Ya, I can't seem to get this to work. I've tried what you suggested, and patching it in different ways, but nothing seems to work.

model = self.worker.model_runner.model
state_dict = {}

for name, param in model.named_parameters():
# only use one layer for testing, otherwise it's too slow
if "layers.0" in name:
state_dict[name] = param.cpu().detach()
return state_dict
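
Related to the thread above, a sketch of the reviewer's TestPolicy(Policy) suggestion; the author reports it did not work in practice, so this is recorded only as the idea under discussion. The import path forge.actors.policy is assumed from the file location src/forge/actors/policy.py.

from monarch.actor import endpoint

from forge.actors.policy import Policy  # assumed module path

class TestPolicy(Policy):
    @endpoint
    async def get_model_params(self):
        # same body as the PR's endpoint: dump a single layer so the transfer stays small
        model = self.worker.model_runner.model
        return {
            name: param.cpu().detach()
            for name, param in model.named_parameters()
            if "layers.0" in name
        }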

def setup_worker(self):
"""Build and Instantiate vLLM worker"""
parallel_config = self.vllm_args.parallel_config
set_multiprocessing_worker_envs(parallel_config)

# Get distributed init info
ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
distributed_init_method = get_distributed_init_method(ip, port)
- all_kwargs = [{}] * parallel_config.world_size
- local_rank = self.rank % torch.accelerator.device_count()

# Calculate local rank properly
device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
local_rank = self.rank % device_count

# Validate local rank
if local_rank >= device_count:
raise ValueError(
f"Local rank {local_rank} exceeds available devices {device_count}"
)

# Calculate driver worker properly
is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0

# Prepare worker kwargs
all_kwargs = [{}] * parallel_config.world_size
all_kwargs[self.rank] = {
"vllm_config": self.vllm_args,
"local_rank": local_rank,
"rank": self.rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": is_driver_worker,
}
- worker = WorkerWrapperBase(self.vllm_args, self.rank)
- worker.init_worker(all_kwargs)
- worker.init_device()
- worker.load_model()
- return worker

logger.info(
f"Setting up worker: rank={self.rank}, local_rank={local_rank}, "
f"is_driver={is_driver_worker}, device_count={device_count}"
)

try:
worker = WorkerWrapperBase(self.vllm_args, self.rank)
worker.init_worker(all_kwargs)
worker.init_device()
worker.load_model()
return worker
except Exception as e:
logger.error(f"Failed to setup worker: {e}")
logger.error(
f"Worker config: rank={self.rank}, local_rank={local_rank}, "
f"device_count={device_count}, world_size={parallel_config.world_size}"
)
raise


def convert_input(prompt=None, prompt_token_ids=None):