Merged
Commits (37 total; changes shown from 14 commits)

46153ea  initial testing (Aug 13, 2025)
a0fc785  more testing (Aug 14, 2025)
9b6fa9f  init works (Aug 14, 2025)
f26d829  use instruct in path (Aug 14, 2025)
8f25f61  somewhat working (Aug 14, 2025)
006f27e  kinda working, memory/timeout issue (Aug 14, 2025)
10cce6b  store and load working! (Aug 14, 2025)
1e4205c  clean up logging (Aug 15, 2025)
d8de194  sharded working (Aug 15, 2025)
aa916eb  it's working? but _get_tensor_parallel_sharding_strategy is hacky: (Aug 15, 2025)
32f1683  it's working (Aug 15, 2025)
a39444d  some cleanups (Aug 18, 2025)
55c6a49  more clean up (Aug 18, 2025)
52bbf3b  clean ups (Aug 18, 2025)
082b138  get rid of if else logic (Aug 18, 2025)
44caf68  mapping (Aug 18, 2025)
e69dbcd  mostly working (Aug 19, 2025)
08ba23e  mostly working 2 (Aug 19, 2025)
c5dd764  mostly working 3 (Aug 19, 2025)
4743217  single test passes (Aug 19, 2025)
dd36d73  single and fsdp works with calculated sharding (Aug 20, 2025)
ac6a212  convert from script to test (Aug 20, 2025)
b944a2e  cleaning things up (Aug 20, 2025)
8bb9710  more cleaning up (Aug 20, 2025)
8d029f5  move sharding to helper (Aug 20, 2025)
a3355f5  move sharding to helper 2 (Aug 20, 2025)
6fed9b6  refactor (Aug 20, 2025)
6e36dd3  use sharding class in policy and test (Aug 20, 2025)
a78be1b  renames (Aug 20, 2025)
300fe86  use test fixture (Aug 20, 2025)
6003b12  use helper in test (Aug 20, 2025)
ec07ba9  remove extra comments (Aug 20, 2025)
e0a1797  remove extra load (Aug 20, 2025)
5af98a1  clean up prints (Aug 20, 2025)
00c4a03  requested changes (Aug 21, 2025)
d0fb772  requested changes 2 (Aug 21, 2025)
bdd2507  use remote dir (Aug 21, 2025)
142 changes: 140 additions & 2 deletions src/forge/actors/policy.py
@@ -12,6 +12,8 @@

import torch
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import get_state_dict

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
@@ -169,6 +171,8 @@ class Policy(Actor):
    enforce_eager: bool = False
    vllm_args: EngineArgs = None
    resources: int = 1
    torchstore: MultiProcessStore = None
    state_dict_key: str = "model_state_dict"

    def __post_init__(self):
        """Build vLLM Arguments
@@ -222,10 +226,117 @@ async def setup(self):
    async def execute_model(self, schedule: SchedulerOutput):
        return self.worker.execute_model(schedule)

    async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
        """
        Load the full state dict from torchstore into a tensor parallel model.
        Uses DTensor's distribution system when available for automatic sharding.
        """
        from torchstore._state_dict_utils import DELIM, MAPPING

        # Get the mapping of stored parameters
        try:
            fetched_mapping = await self.torchstore.get(
                f"{self.state_dict_key}{DELIM}{MAPPING}"
            )
        except Exception as e:
            raise RuntimeError(
                f"Could not load mapping for state dict key {self.state_dict_key}: {e}"
            ) from e

        logger.info(f"Loading {len(fetched_mapping)} parameters with tensor parallel support")

        updated_count = 0

        for param_name in fetched_mapping.keys():
            if param_name not in current_state_dict:
                logger.warning(f"Parameter {param_name} not found in current model, skipping")
                continue

            current_tensor = current_state_dict[param_name]

            try:
                # Load the full tensor from torchstore
                stored_tensor = await self.torchstore.get(f"{self.state_dict_key}{DELIM}{param_name}")

                # Check whether the current tensor is a DTensor
                if hasattr(current_tensor, "_spec") and current_tensor._spec is not None:
                    # This is a DTensor - use DTensor's distribution system
                    logger.debug(f"Distributing DTensor parameter {param_name} with spec: {current_tensor._spec}")

                    try:
                        from torch.distributed._tensor import distribute_tensor

                        # Get the DTensor's distribution spec
                        device_mesh = current_tensor.device_mesh
                        placements = current_tensor._spec.placements

                        # Distribute the stored tensor according to the current tensor's spec
                        distributed_tensor = distribute_tensor(stored_tensor, device_mesh, placements)

                        # Copy the local shard into the current tensor
                        current_state_dict[param_name].copy_(distributed_tensor._local_tensor)
                        logger.debug(f"Successfully distributed DTensor parameter {param_name}")

                    except Exception as dtensor_e:
                        logger.warning(f"Failed to distribute DTensor {param_name}: {dtensor_e}")
                        continue

                else:
                    # Regular tensor - direct copy (shapes must match)
                    if stored_tensor.shape != current_tensor.shape:
                        raise RuntimeError(
                            f"Shape mismatch for regular tensor {param_name}: {stored_tensor.shape} vs {current_tensor.shape}"
                        )

                    current_state_dict[param_name].copy_(stored_tensor)
                    logger.debug(f"Copied regular parameter {param_name}")

                updated_count += 1

            except Exception as e:
                logger.warning(f"Failed to load parameter {param_name}: {e}")
                continue

        logger.info(f"Successfully updated {updated_count} parameters")

        if updated_count == 0:
            raise RuntimeError("No parameters were successfully updated")

    @endpoint
    async def update(self):
        """Update model weights by reading the state dict from torchstore."""
        if self.torchstore is None:
            logger.warning("No torchstore configured, skipping model update")
            return False

        try:
            # Get the current model from the worker
            model = self.worker.model_runner.model
            current_state_dict = model.state_dict()

            if self.tensor_parallel_size > 1:
                logger.info("Loading state dict with tensor parallel sharding")
                await self._load_tensor_parallel_state_dict(current_state_dict)
            else:
                logger.info("Loading state dict for single GPU model")
                await get_state_dict(
                    self.torchstore, self.state_dict_key, current_state_dict
                )

            # Load the updated state dict into the model
            model.load_state_dict(current_state_dict, strict=True)

            logger.info("Successfully updated model weights from torchstore")
            return True

        except Exception as e:
            logger.error(f"Failed to update model from torchstore: {e}")
            import traceback

            logger.error(f"Traceback: {traceback.format_exc()}")
            return False

    @endpoint
    async def setup_kv_cache(self):
@@ -261,6 +372,33 @@ async def setup_kv_cache(self):
    async def get_vllm_args(self):
        return self.vllm_args

    @endpoint
    async def test_model_info(self):
        """Get basic model information for testing purposes."""
        model = self.worker.model_runner.model

        # Get basic model info that doesn't require a forward pass
        model_info = {
            "num_parameters": sum(p.numel() for p in model.parameters()),
            "device": str(next(model.parameters()).device),
            "dtype": str(next(model.parameters()).dtype),
            "model_type": type(model).__name__,
        }

        # Get a sample of parameter values for comparison.
        # Use the embedding layer weights, as they are typically the first parameters.
        for name, param in model.named_parameters():
            if "embed" in name.lower() and param.numel() >= 10:
                # Convert to float32 before the numpy conversion to handle BFloat16
                sample_weights = param.flatten()[:10].cpu().detach().float()
                model_info["sample_weights"] = sample_weights.numpy().tolist()
                model_info["sample_param_name"] = name
                break

        return model_info

    def setup_worker(self):
        """Build and Instantiate vLLM worker"""
        parallel_config = self.vllm_args.parallel_config
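The DTensor path in _load_tensor_parallel_state_dict leans on distribute_tensor: the full tensor fetched from torchstore is sharded to match the placements of the existing DTensor parameter, and only the local shard is copied in. Below is a minimal standalone sketch of that mechanism, assuming a 1-D tensor-parallel mesh launched with torchrun; the mesh size, tensor shape, and variable names are illustrative and not taken from this PR.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Assumes a torchrun launch so LOCAL_RANK and the rendezvous env vars are set.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
world_size = dist.get_world_size()

# 1-D mesh over all ranks, standing in for the tensor-parallel group.
mesh = init_device_mesh("cuda", (world_size,))

# Full (unsharded) weight, as it would come back from torchstore on every rank.
full_weight = torch.randn(1024, 512, device="cuda")

# Shard dim 0 across the mesh; in the PR the placements come from the existing
# DTensor parameter (current_tensor._spec.placements) rather than being hard-coded.
dtensor = distribute_tensor(full_weight, mesh, [Shard(0)])

# Each rank holds only its slice; this local shard is what copy_() writes into
# the vLLM parameter (assuming 1024 divides evenly by the number of ranks).
local_shard = dtensor.to_local()
print(local_shard.shape)  # e.g. torch.Size([512, 512]) with 2 ranks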
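The update endpoint also assumes the producer side has already written a flat, fully materialized state dict into torchstore under state_dict_key, plus a MAPPING entry listing the stored parameter names. A small read-side sketch of that layout, using only the get calls that appear in this diff; the store handle, key, and parameter name below are placeholders.

from torchstore._state_dict_utils import DELIM, MAPPING


async def fetch_param(store, key="model_state_dict", param="model.embed_tokens.weight"):
    """Fetch one full (unsharded) parameter the same way the Policy.update path does."""
    # The MAPPING entry enumerates every parameter stored under `key`.
    mapping = await store.get(f"{key}{DELIM}{MAPPING}")
    if param not in mapping:
        raise KeyError(f"{param} was not written to torchstore under {key}")
    # Each parameter is stored as a full tensor keyed by "<key><DELIM><param_name>".
    return await store.get(f"{key}{DELIM}{param}")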