
Commit e69dbcd

ankitageorge committed: mostly working
1 parent 44caf68 commit e69dbcd

File tree

2 files changed: +186 −90 lines changed


src/forge/actors/policy.py

Lines changed: 179 additions & 31 deletions
@@ -34,6 +34,8 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
+from torchstore._state_dict_utils import DELIM, MAPPING, get_state_dict
+
 logger = logging.getLogger(__name__)
 
 
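The new top-level import pulls in torchstore's state dict helpers. For orientation, the weight-sync code added below addresses each parameter as its own torchstore entry, with the state dict key and the parameter's fully qualified name joined by DELIM. A minimal sketch of that key layout; the key and parameter name here are hypothetical:

from torchstore._state_dict_utils import DELIM

state_dict_key = "policy_weights"  # hypothetical; the actor reads self.state_dict_key
param_name = "model.layers.0.self_attn.qkv_proj.weight"  # hypothetical parameter

# Matches the f"{self.state_dict_key}{DELIM}{param_name}" lookups in the diff below.
torchstore_key = f"{state_dict_key}{DELIM}{param_name}"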
@@ -194,7 +196,7 @@ def __post_init__(self):
             tensor_parallel_size=self.tensor_parallel_size,
             pipeline_parallel_size=self.pipeline_parallel_size,
             enforce_eager=self.enforce_eager,
-            gpu_memory_utilization=0.7,
+            gpu_memory_utilization=0.4,
         )
         # Original method returns False when not run in the main thread
         self.vllm_args._is_v1_supported_oracle = lambda *_: True
@@ -227,38 +229,156 @@ async def setup(self):
     async def execute_model(self, schedule: SchedulerOutput):
         return self.worker.execute_model(schedule)
 
+    def _get_tensor_parallel_sharding_strategy(self, param_name: str) -> tuple[int, bool]:
+        """
+        Determine the sharding strategy for a parameter in a tensor parallel setup.
+
+        Returns:
+            tuple[int, bool]: (shard_dimension, is_sharded)
+                - shard_dimension: which dimension to shard (0 or 1)
+                - is_sharded: whether this parameter should be sharded at all
+
+        Based on vLLM's tensor parallel implementation for LLaMA models:
+        - Embedding layers: shard along the vocab dimension (dim 0)
+        - Attention projections: qkv_proj shards along the hidden dimension (dim 0),
+          o_proj along the input dimension (dim 1)
+        - MLP projections: gate_proj/up_proj shard along the hidden dimension (dim 0),
+          down_proj along the input dimension (dim 1)
+        - Layer norms: not sharded (replicated)
+        - Output layer: shard along the vocab dimension (dim 0)
+        """
+        # Parameters that are not sharded (replicated across all tensor parallel ranks)
+        if any(keyword in param_name for keyword in ['norm', 'bias', 'rotary_emb']):
+            return 0, False
+
+        # Embedding layers - shard along the vocab dimension (dim 0)
+        if 'embed_tokens' in param_name or 'lm_head' in param_name:
+            return 0, True
+
+        # Attention projections
+        if 'qkv_proj' in param_name:
+            # Fused input projection: shard the output dimension (dim 0)
+            return 0, True
+        elif 'o_proj' in param_name:
+            # Output projection: shard the input dimension (dim 1)
+            return 1, True
+
+        # MLP projections
+        elif any(proj in param_name for proj in ['gate_proj', 'up_proj']):
+            # Input projections: shard the output dimension (dim 0)
+            return 0, True
+        elif 'down_proj' in param_name:
+            # Output projection: shard the input dimension (dim 1)
+            return 1, True
+
+        # Default: shard along dim 0
+        return 0, True
+
+    def _calculate_tensor_shard(self, full_tensor: torch.Tensor, shard_dim: int) -> torch.Tensor:
+        """
+        Calculate the shard of a full tensor for the current tensor parallel rank.
+
+        Args:
+            full_tensor: The full tensor to shard
+            shard_dim: Which dimension to shard along (0 or 1)
+
+        Returns:
+            torch.Tensor: The sharded tensor for this rank
+        """
+        tp_rank = self.rank % self.tensor_parallel_size
+        tensor_size = full_tensor.shape[shard_dim]
+
+        if tensor_size % self.tensor_parallel_size != 0:
+            raise ValueError(
+                f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
+                f"across {self.tensor_parallel_size} ranks: not evenly divisible"
+            )
+
+        shard_size = tensor_size // self.tensor_parallel_size
+        start_idx = tp_rank * shard_size
+        end_idx = start_idx + shard_size
+
+        if shard_dim == 0:
+            return full_tensor[start_idx:end_idx]
+        elif shard_dim == 1:
+            return full_tensor[:, start_idx:end_idx]
+        else:
+            raise ValueError(f"Unsupported shard dimension: {shard_dim}")
+
+    async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
+        """
+        Load the full state dict from torchstore into a tensor parallel model,
+        sharding each parameter deterministically for this rank.
+        """
+        updated_count = 0
+
+        for param_name in current_state_dict.keys():
+            current_tensor = current_state_dict[param_name]
+
+            # Load the full tensor from torchstore
+            stored_tensor = await self.torchstore.get(
+                f"{self.state_dict_key}{DELIM}{param_name}"
+            )
+
+            # Determine the sharding strategy for this parameter
+            shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name)
+
+            if not is_sharded:
+                # Replicated parameter - shapes must match exactly
+                if stored_tensor.shape != current_tensor.shape:
+                    raise ValueError(
+                        f"Replicated parameter {param_name} has mismatched shapes: "
+                        f"{stored_tensor.shape} vs {current_tensor.shape}"
+                    )
+
+                # Direct copy for replicated parameters
+                current_state_dict[param_name].copy_(stored_tensor)
+
+            else:
+                # Shard the full tensor for this rank
+                sharded_tensor = self._calculate_tensor_shard(stored_tensor, shard_dim)
+
+                if sharded_tensor.shape != current_tensor.shape:
+                    raise ValueError(
+                        f"Calculated shard for {param_name} has wrong shape: "
+                        f"{sharded_tensor.shape} vs expected {current_tensor.shape}"
+                    )
+
+                current_state_dict[param_name].copy_(sharded_tensor)
+
+            updated_count += 1
+
+        logger.info(f"Successfully updated {updated_count} parameters")
+
     @endpoint
     async def update(self):
         """Update model weights by reading state dict from torchstore"""
+
         if self.torchstore is None:
-            logger.warning("No torchstore configured, skipping model update")
-            return False
+            raise Exception("No torchstore configured, cannot update model weights")
 
-        from torchstore._state_dict_utils import DELIM
+        logger.info(f"Starting model update from torchstore with key: {self.state_dict_key}")
 
         # Get the current model from the worker
         model = self.worker.model_runner.model
         current_state_dict = model.state_dict()
 
-        updated_count = 0
-        # Iterate through each parameter in current state dict and load directly using torchstore.get
-        for param_name, current_tensor in current_state_dict.items():
-            # Use torchstore.get to load directly into the current tensor
-            # This automatically handles both tensor parallelized and regular tensors
-            try:
-                await self.torchstore.get(
-                    f"{self.state_dict_key}{DELIM}{param_name}",
-                    current_tensor,
-                )
-                logger.info(f"Successfully updated {param_name} from torchstore")
-                updated_count += 1
-            except Exception as e:
-                logger.error(
-                    f"Failed to load parameter {param_name} from torchstore: {e}"
-                )
-                continue
-
-        logger.info(f"Successfully updated {updated_count} parameters from torchstore")
+        logger.info(f"Current state dict has {len(current_state_dict)} parameters")
+        logger.info(f"Tensor parallel size: {self.tensor_parallel_size}")
+
+        if self.tensor_parallel_size > 1:
+            # Tensor parallel model - use the deterministic sharding strategy
+            logger.info("Loading state dict with tensor parallel sharding...")
+            await self._load_tensor_parallel_state_dict(current_state_dict)
+        else:
+            # Single GPU model - use standard loading
+            logger.info("Loading state dict for single GPU model...")
+            await get_state_dict(self.torchstore, self.state_dict_key, current_state_dict)
+
+        # Load the updated state dict into the model
+        model.load_state_dict(current_state_dict, strict=True)
+
+        logger.info("Successfully updated model weights from torchstore")
+
 
     @endpoint
     async def setup_kv_cache(self):
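As a sanity check on the sharding logic above, here is a standalone sketch (not part of this commit) of the same contiguous-chunk slicing that _calculate_tensor_shard performs; shard_for_rank and the 2-rank world size are illustrative assumptions. Concatenating the per-rank shards along the shard dimension reconstructs the full tensor:

import torch

def shard_for_rank(full_tensor: torch.Tensor, shard_dim: int, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Equal contiguous chunks along shard_dim, one per tensor parallel rank,
    # mirroring the slicing in _calculate_tensor_shard.
    size = full_tensor.shape[shard_dim]
    assert size % tp_size == 0, "dimension must divide evenly across ranks"
    shard_size = size // tp_size
    return full_tensor.narrow(shard_dim, tp_rank * shard_size, shard_size)

full = torch.arange(32.0).reshape(4, 8)
# gate_proj/up_proj-style weights shard dim 0; o_proj/down_proj shard dim 1.
assert torch.equal(torch.cat([shard_for_rank(full, 0, r, 2) for r in range(2)], dim=0), full)
assert torch.equal(torch.cat([shard_for_rank(full, 1, r, 2) for r in range(2)], dim=1), full)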
@@ -297,7 +417,6 @@ async def get_vllm_args(self):
     @endpoint
     async def test_model_info(self):
         """Get basic model information for testing purposes"""
-        import torch
 
         model = self.worker.model_runner.model
 
@@ -325,23 +444,52 @@ def setup_worker(self):
         """Build and Instantiate vLLM worker"""
         parallel_config = self.vllm_args.parallel_config
         set_multiprocessing_worker_envs(parallel_config)
+
+        # Get distributed init info
         ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
         distributed_init_method = get_distributed_init_method(ip, port)
-        all_kwargs = [{}] * parallel_config.world_size
-        local_rank = self.rank % torch.accelerator.device_count()
+
+        # Calculate the local rank from the available device count
+        device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
+        local_rank = self.rank % device_count
+
+        # Validate the local rank
+        if local_rank >= device_count:
+            raise ValueError(
+                f"Local rank {local_rank} exceeds available devices {device_count}"
+            )
+
+        # The first rank in each tensor parallel group is the driver worker
         is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0
+
+        # Prepare worker kwargs
+        all_kwargs = [{}] * parallel_config.world_size
         all_kwargs[self.rank] = {
             "vllm_config": self.vllm_args,
             "local_rank": local_rank,
             "rank": self.rank,
             "distributed_init_method": distributed_init_method,
             "is_driver_worker": is_driver_worker,
         }
-        worker = WorkerWrapperBase(self.vllm_args, self.rank)
-        worker.init_worker(all_kwargs)
-        worker.init_device()
-        worker.load_model()
-        return worker
+
+        logger.info(
+            f"Setting up worker: rank={self.rank}, local_rank={local_rank}, "
+            f"is_driver={is_driver_worker}, device_count={device_count}"
+        )
+
+        try:
+            worker = WorkerWrapperBase(self.vllm_args, self.rank)
+            worker.init_worker(all_kwargs)
+            worker.init_device()
+            worker.load_model()
+            return worker
+        except Exception as e:
+            logger.error(f"Failed to setup worker: {e}")
+            logger.error(
+                f"Worker config: rank={self.rank}, local_rank={local_rank}, "
+                f"device_count={device_count}, world_size={parallel_config.world_size}"
+            )
+            raise
 
 
 def convert_input(prompt=None, prompt_token_ids=None):
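The local rank and driver-worker bookkeeping in setup_worker is easy to check in isolation. A small sketch under an assumed topology of two 4-GPU nodes with tensor_parallel_size=4 (all values hypothetical, rank_info is not part of the commit):

def rank_info(rank: int, device_count: int, tensor_parallel_size: int) -> tuple[int, bool]:
    local_rank = rank % device_count  # device index on this node
    is_driver = rank % tensor_parallel_size == 0  # first rank of each TP group
    return local_rank, is_driver

for rank in range(8):
    print(rank, rank_info(rank, device_count=4, tensor_parallel_size=4))
# Ranks 0 and 4 come out as drivers, one per tensor parallel group.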
