|
34 | 34 | from vllm.v1.structured_output import StructuredOutputManager |
35 | 35 | from vllm.worker.worker_base import WorkerWrapperBase |
36 | 36 |
|
| 37 | +from torchstore._state_dict_utils import DELIM, MAPPING, get_state_dict |
| 38 | + |
37 | 39 | logger = logging.getLogger(__name__) |
38 | 40 |
|
39 | 41 |
|
@@ -194,7 +196,7 @@ def __post_init__(self): |
194 | 196 | tensor_parallel_size=self.tensor_parallel_size, |
195 | 197 | pipeline_parallel_size=self.pipeline_parallel_size, |
196 | 198 | enforce_eager=self.enforce_eager, |
197 | | - gpu_memory_utilization=0.7, |
| 199 | + gpu_memory_utilization=0.4, |
198 | 200 | ) |
199 | 201 | # Original method returns False when not run in the main thread |
200 | 202 | self.vllm_args._is_v1_supported_oracle = lambda *_: True |
@@ -227,38 +229,156 @@ async def setup(self): |
227 | 229 | async def execute_model(self, schedule: SchedulerOutput): |
228 | 230 | return self.worker.execute_model(schedule) |
229 | 231 |
|
| 232 | + def _get_tensor_parallel_sharding_strategy(self, param_name: str) -> tuple[int, bool]: |
| 233 | + """ |
| 234 | + Determine the sharding strategy for a parameter in tensor parallel setup. |
| 235 | + |
| 236 | + Returns: |
| 237 | + tuple[int, bool]: (shard_dimension, is_sharded) |
| 238 | + - shard_dimension: Which dimension to shard (0 or 1) |
| 239 | + - is_sharded: Whether this parameter should be sharded at all |
| 240 | + |
| 241 | + Based on vLLM's tensor parallel implementation for LLaMA models: |
| 242 | + - Embedding layers: shard along vocab dimension (dim 0) |
| 243 | + - Attention projections: q/k/v projections (qkv_proj) shard along the output dimension (dim 0), o_proj along the input dimension (dim 1) |
| 244 | + - MLP projections: gate_proj/up_proj shard along the output dimension (dim 0), down_proj along the input dimension (dim 1) |
| 245 | + - Layer norms: not sharded (replicated) |
| 246 | + - Output layer: shard along vocab dimension (dim 0) |
| 247 | + """ |
| 248 | + # Parameters that are not sharded (replicated across all tensor parallel ranks) |
| 249 | + if any(keyword in param_name for keyword in [ |
| 250 | + 'norm', 'bias', 'rotary_emb' |
| 251 | + ]): |
| 252 | + return 0, False |
| 253 | + |
| 254 | + # Embedding layers - shard along vocab dimension (dim 0) |
| 255 | + if 'embed_tokens' in param_name or 'lm_head' in param_name: |
| 256 | + return 0, True |
| 257 | + |
| 258 | + # Attention projections |
| 259 | + if 'qkv_proj' in param_name: |
| 260 | + # Input projections: shard output dimension (dim 0) |
| 261 | + return 0, True |
| 262 | + elif 'o_proj' in param_name: |
| 263 | + # Output projection: shard input dimension (dim 1) |
| 264 | + return 1, True |
| 265 | + |
| 266 | + # MLP projections |
| 267 | + elif any(proj in param_name for proj in ['gate_proj', 'up_proj']): |
| 268 | + # Input projections: shard output dimension (dim 0) |
| 269 | + return 0, True |
| 270 | + elif 'down_proj' in param_name: |
| 271 | + # Output projection: shard input dimension (dim 1) |
| 272 | + return 1, True |
| 273 | + |
| 274 | + # Default: assume the parameter is sharded along the output dimension (dim 0) |
| 275 | + return 0, True |
| 276 | + |
| 277 | + def _calculate_tensor_shard(self, full_tensor: torch.Tensor, shard_dim: int) -> torch.Tensor: |
| 278 | + """ |
| 279 | + Calculate the shard of a full tensor for the current tensor parallel rank. |
| 280 | + |
| 281 | + Args: |
| 282 | + full_tensor: The full tensor to shard |
| 283 | + shard_dim: Which dimension to shard along (0 or 1) |
| 284 | + |
| 285 | + Returns: |
| 286 | + torch.Tensor: The sharded tensor for this rank |
| 287 | + """ |
| 288 | + tp_rank = self.rank % self.tensor_parallel_size |
| 289 | + tensor_size = full_tensor.shape[shard_dim] |
| 290 | + |
| 291 | + if tensor_size % self.tensor_parallel_size != 0: |
| 292 | + raise ValueError( |
| 293 | + f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} " |
| 294 | + f"across {self.tensor_parallel_size} ranks: not evenly divisible" |
| 295 | + ) |
| 296 | + |
| 297 | + shard_size = tensor_size // self.tensor_parallel_size |
| 298 | + start_idx = tp_rank * shard_size |
| 299 | + end_idx = start_idx + shard_size |
| 300 | + |
| 301 | + if shard_dim == 0: |
| 302 | + return full_tensor[start_idx:end_idx] |
| 303 | + elif shard_dim == 1: |
| 304 | + return full_tensor[:, start_idx:end_idx] |
| 305 | + else: |
| 306 | + raise ValueError(f"Unsupported shard dimension: {shard_dim}") |
| 307 | + |
| 308 | + async def _load_tensor_parallel_state_dict(self, current_state_dict: dict): |
| 309 | + """ |
| 310 | + Load full state dict from torchstore into tensor parallel model with deterministic sharding. |
| 311 | + """ |
| 312 | + |
| 313 | + updated_count = 0 |
| 314 | + |
| 315 | + for param_name in current_state_dict.keys(): |
| 316 | + current_tensor = current_state_dict[param_name] |
| 317 | + |
| 318 | + # Load the full tensor from torchstore |
| 319 | + stored_tensor = await self.torchstore.get(f"{self.state_dict_key}{DELIM}{param_name}") |
| 320 | + |
| 321 | + # Determine sharding strategy for this parameter |
| 322 | + shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name) |
| 323 | + |
| 324 | + if not is_sharded: |
| 325 | + # Parameter is replicated - shapes should match exactly |
| 326 | + if stored_tensor.shape != current_tensor.shape: |
| 327 | + raise ValueError( |
| 328 | + f"Replicated parameter {param_name} has mismatched shapes: " |
| 329 | + f"{stored_tensor.shape} vs {current_tensor.shape}, skipping" |
| 330 | + ) |
| 331 | + |
| 332 | + # Direct copy for replicated parameters |
| 333 | + current_state_dict[param_name].copy_(stored_tensor) |
| 334 | + |
| 335 | + else: |
| 336 | + # Need to shard the full tensor |
| 337 | + sharded_tensor = self._calculate_tensor_shard(stored_tensor, shard_dim) |
| 338 | + |
| 339 | + if sharded_tensor.shape != current_tensor.shape: |
| 340 | + raise ValueError( |
| 341 | + f"Calculated shard for {param_name} has wrong shape: " |
| 342 | + f"{sharded_tensor.shape} vs expected {current_tensor.shape}, skipping" |
| 343 | + ) |
| 344 | + |
| 345 | + current_state_dict[param_name].copy_(sharded_tensor) |
| 346 | + |
| 347 | + updated_count += 1 |
| 348 | + |
| 349 | + logger.info(f"Successfully updated {updated_count} parameters") |
| 350 | + |
230 | 351 | @endpoint |
231 | 352 | async def update(self): |
232 | 353 | """Update model weights by reading state dict from torchstore""" |
| 354 | + |
233 | 355 | if self.torchstore is None: |
234 | | - logger.warning("No torchstore configured, skipping model update") |
235 | | - return False |
| 356 | + raise RuntimeError("No torchstore configured; cannot update model weights") |
| 357 | + |
236 | 358 |
|
237 | | - from torchstore._state_dict_utils import DELIM |
| 359 | + logger.info(f"Starting model update from torchstore with key: {self.state_dict_key}") |
238 | 360 |
|
239 | 361 | # Get the current model from the worker |
240 | 362 | model = self.worker.model_runner.model |
241 | 363 | current_state_dict = model.state_dict() |
242 | 364 |
|
243 | | - updated_count = 0 |
244 | | - # Iterate through each parameter in current state dict and load directly using torchstore.get |
245 | | - for param_name, current_tensor in current_state_dict.items(): |
246 | | - # Use torchstore.get to load directly into the current tensor |
247 | | - # This automatically handles both tensor parallelized and regular tensors |
248 | | - try: |
249 | | - await self.torchstore.get( |
250 | | - f"{self.state_dict_key}{DELIM}{param_name}", |
251 | | - current_tensor, |
252 | | - ) |
253 | | - logger.info(f"Successfully updated {param_name} from torchstore") |
254 | | - updated_count += 1 |
255 | | - except Exception as e: |
256 | | - logger.error( |
257 | | - f"Failed to load parameter {param_name} from torchstore: {e}" |
258 | | - ) |
259 | | - continue |
260 | | - |
261 | | - logger.info(f"Successfully updated {updated_count} parameters from torchstore") |
| 365 | + logger.info(f"Current state dict has {len(current_state_dict)} parameters") |
| 366 | + logger.info(f"Tensor parallel size: {self.tensor_parallel_size}") |
| 367 | + |
| 368 | + if self.tensor_parallel_size > 1: |
| 369 | + # Tensor parallel model - use deterministic sharding strategy |
| 370 | + logger.info("Loading state dict with tensor parallel sharding...") |
| 371 | + await self._load_tensor_parallel_state_dict(current_state_dict) |
| 372 | + else: |
| 373 | + # Single GPU model - use standard loading |
| 374 | + logger.info("Loading state dict for single GPU model...") |
| 375 | + await get_state_dict(self.torchstore, self.state_dict_key, current_state_dict) |
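|  | + # (assumes torchstore's get_state_dict fills current_state_dict from the entries stored under state_dict_key) |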
| 376 | + |
| 377 | + # Load the updated state dict into the model |
| 378 | + model.load_state_dict(current_state_dict, strict=True) |
| 379 | + |
| 380 | + logger.info("Successfully updated model weights from torchstore") |
| 381 | + |
262 | 382 |
|
263 | 383 | @endpoint |
264 | 384 | async def setup_kv_cache(self): |
@@ -297,7 +417,6 @@ async def get_vllm_args(self): |
297 | 417 | @endpoint |
298 | 418 | async def test_model_info(self): |
299 | 419 | """Get basic model information for testing purposes""" |
300 | | - import torch |
301 | 420 |
|
302 | 421 | model = self.worker.model_runner.model |
303 | 422 |
|
@@ -325,23 +444,52 @@ def setup_worker(self): |
325 | 444 | """Build and Instantiate vLLM worker""" |
326 | 445 | parallel_config = self.vllm_args.parallel_config |
327 | 446 | set_multiprocessing_worker_envs(parallel_config) |
| 447 | + |
| 448 | + # Get distributed init info |
328 | 449 | ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT") |
329 | 450 | distributed_init_method = get_distributed_init_method(ip, port) |
330 | | - all_kwargs = [{}] * parallel_config.world_size |
331 | | - local_rank = self.rank % torch.accelerator.device_count() |
| 451 | + |
| 452 | + # Map the global rank to a local device index |
| 453 | + device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1 |
| 454 | + local_rank = self.rank % device_count |
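|  | + # e.g. with 8 visible devices, global rank 10 maps to local device 10 % 8 = 2 |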
| 455 | + |
| 456 | + # Validate local rank |
| 457 | + if local_rank >= device_count: |
| 458 | + raise ValueError( |
| 459 | + f"Local rank {local_rank} exceeds available devices {device_count}" |
| 460 | + ) |
| 461 | + |
| 462 | + # The first rank in each tensor parallel group acts as the driver worker |
332 | 463 | is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0 |
| 464 | + |
| 465 | + # Prepare worker kwargs |
| 466 | + all_kwargs = [{}] * parallel_config.world_size |
333 | 467 | all_kwargs[self.rank] = { |
334 | 468 | "vllm_config": self.vllm_args, |
335 | 469 | "local_rank": local_rank, |
336 | 470 | "rank": self.rank, |
337 | 471 | "distributed_init_method": distributed_init_method, |
338 | 472 | "is_driver_worker": is_driver_worker, |
339 | 473 | } |
340 | | - worker = WorkerWrapperBase(self.vllm_args, self.rank) |
341 | | - worker.init_worker(all_kwargs) |
342 | | - worker.init_device() |
343 | | - worker.load_model() |
344 | | - return worker |
| 474 | + |
| 475 | + logger.info( |
| 476 | + f"Setting up worker: rank={self.rank}, local_rank={local_rank}, " |
| 477 | + f"is_driver={is_driver_worker}, device_count={device_count}" |
| 478 | + ) |
| 479 | + |
| 480 | + try: |
| 481 | + worker = WorkerWrapperBase(self.vllm_args, self.rank) |
| 482 | + worker.init_worker(all_kwargs) |
| 483 | + worker.init_device() |
| 484 | + worker.load_model() |
| 485 | + return worker |
| 486 | + except Exception as e: |
| 487 | + logger.error(f"Failed to setup worker: {e}") |
| 488 | + logger.error( |
| 489 | + f"Worker config: rank={self.rank}, local_rank={local_rank}, " |
| 490 | + f"device_count={device_count}, world_size={parallel_config.world_size}" |
| 491 | + ) |
| 492 | + raise |
345 | 493 |
|
346 | 494 |
|
347 | 495 | def convert_input(prompt=None, prompt_token_ids=None): |