|
13 | 13 | from collections.abc import Mapping |
14 | 14 | from copy import copy |
15 | 15 | from dataclasses import dataclass, field |
| 16 | +from typing import Optional |
16 | 17 |
|
17 | 18 | import torch |
18 | 19 | import torchstore as ts |
19 | | -from monarch.actor import current_rank, endpoint, ProcMesh |
| 20 | +from monarch.actor import current_rank, endpoint, ProcMesh, this_host |
| 21 | + |
20 | 22 | from vllm.config import VllmConfig |
21 | 23 |
|
22 | 24 | from vllm.engine.arg_utils import EngineArgs |
|
60 | 62 | from forge.observability.metrics import record_metric, Reduce |
61 | 63 | from forge.observability.perf_tracker import Tracer |
62 | 64 | from forge.types import ProcessConfig |
| 65 | +from forge.util._shared_tensor import SharedTensor, SharedTensorHandle |
63 | 66 |
|
64 | 67 | logger = logging.getLogger(__name__) |
65 | 68 | logger.setLevel(logging.INFO) |
@@ -92,6 +95,8 @@ class Generator(ForgeActor): |
92 | 95 | engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) |
93 | 96 | sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) |
94 | 97 | use_dcp_for_weight_sync: bool | None = None |
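| | + # Prefetch pushed weights from torchstore into host shared memory, using n_fetcher_procs local fetcher processes. |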
| 98 | + prefetch_weights_to_shm: bool = True |
| 99 | + n_fetcher_procs: int = 8 |
95 | 100 |
|
96 | 101 | def __post_init__(self): |
97 | 102 | super().__init__() |
@@ -226,11 +231,61 @@ async def setup(self): |
226 | 231 | log_stats=None, |
227 | 232 | ) |
228 | 233 | self._start_processing() |
| 234 | + if self.prefetch_weights_to_shm: |
| 235 | + self._spawn_fetchers() |
| 236 | + |
| 237 | + def _spawn_fetchers(self): |
| 238 | + """Spawn weight fetchers that prefetch weights from torchstore to shared memory.""" |
| 239 | + # TODO: this assumes the generator is on the same host as the worker |
| 240 | + # and only works for single host generators. Figure out how to support |
| 241 | + # generators whose workers span multiple hosts. |
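| | + # Fetchers run as local processes so the shared-memory segments they create are visible to the colocated GPU workers. |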
| 242 | + fetcher_procs = this_host().spawn_procs( |
| 243 | + per_host={"procs": self.n_fetcher_procs} |
| 244 | + ) |
| 245 | + self._fetcher_procs = fetcher_procs |
| 246 | + self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher) |
229 | 247 |
|
230 | 248 | def _start_processing(self): |
231 | 249 | if self._run_task is None or self._run_task.done(): |
232 | 250 | self._run_task = asyncio.create_task(self.run()) |
233 | 251 |
|
| 252 | + async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]): |
| 253 | + for handle in state_dict.values(): |
| 254 | + handle.drop() |
| 255 | + |
| 256 | + async def _fetch_weights( |
| 257 | + self, |
| 258 | + version: int, |
| 259 | + ) -> dict[str, SharedTensorHandle]: |
| 260 | + """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}.""" |
| 261 | + t = Tracer("generator_perf/_fetch_weights") |
| 262 | + t.start() |
| 263 | + prefix = get_param_prefix(version) |
| 264 | + matching_keys = await ts.keys(prefix) |
| 265 | + hf_param_names = [extract_param_name(key) for key in matching_keys] |
| 266 | + |
| 267 | + n_fetchers = self.weight_fetchers.size() |
| 268 | + |
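| | + # Round-robin the parameter names across fetchers so each pulls an equal share in parallel. |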
| 269 | + def split_keys(keys): |
| 270 | + return [keys[i::n_fetchers] for i in range(n_fetchers)] |
| 271 | + |
| 272 | + futures = [] |
| 273 | + for i, names in enumerate(split_keys(hf_param_names)): |
| 274 | + fut = self.weight_fetchers.slice(procs=i).fetch.call_one( |
| 275 | + version=version, param_names=names |
| 276 | + ) |
| 277 | + futures.append(fut) |
| 278 | + |
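| | + # The calls above were dispatched without awaiting, so the fetchers work concurrently; collect their partial dicts here. |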
| 279 | + sub_state_dicts = [await fut for fut in futures] |
| 280 | + |
| 281 | + state_dict = {} |
| 282 | + for sd in sub_state_dicts: |
| 283 | + state_dict.update(sd) |
| 284 | + |
| 285 | + t.stop() |
| 286 | + |
| 287 | + return state_dict |
| 288 | + |
234 | 289 | @endpoint |
235 | 290 | async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: |
236 | 291 | """Generate a response for the given prompt |
@@ -384,6 +439,12 @@ async def update_weights(self, version: int) -> None: |
384 | 439 | >>> await trainer.push_weights() |
385 | 440 | >>> generator.update_weights(version) |
386 | 441 | """ |
| 442 | + # TODO: enable shared memory prefetch for DCP-based weight sync |
| 443 | + if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync: |
| 444 | + logger.info(f"[Generator] Fetching weights for v{version} to shared memory") |
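| | + # Start the prefetch in the background; it overlaps with draining in-flight requests below. |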
| 445 | + fetch_fut = asyncio.create_task(self._fetch_weights(version)) |
| 446 | + else: |
| 447 | + fetch_fut = None |
387 | 448 | # Serialize updates (only one update at a time) |
388 | 449 | async with self.update_lock: |
389 | 450 | # Grab the lock to stop accepting requests and wait on pending requests |
@@ -415,8 +476,19 @@ async def update_weights(self, version: int) -> None: |
415 | 476 | ) |
416 | 477 |
|
417 | 478 | logger.debug(f"Starting weight update on {self.__class__.__name__}") |
418 | | - # Call update_weights on every generator worker |
419 | | - await self.worker.update_weights.call(version=version) |
| 479 | + |
| 480 | + if fetch_fut is not None: |
| 481 | + t = Tracer("generator_perf/waiting_for_fetch_weights") |
| 482 | + t.start() |
| 483 | + fetched_weights = await fetch_fut |
| 484 | + t.stop() |
| 485 | + # Call update_weights on every generator worker |
| 486 | + await self.worker.update_weights.call( |
| 487 | + shared_memory_state_dict=fetched_weights |
| 488 | + ) |
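| | + # All workers have loaded the weights; free the shared-memory segments. |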
| 489 | + await self._drop_shared_memory(fetched_weights) |
| 490 | + else: |
| 491 | + await self.worker.update_weights.call(version=version) |
420 | 492 | self.generator_version = version |
421 | 493 |
|
422 | 494 | # After updating the weights, we need to reset the KV cache |
@@ -490,6 +562,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] |
490 | 562 | await actor.stop.call() |
491 | 563 | await stop_proc_mesh(actor._worker_procs) |
492 | 564 | await stop_proc_mesh(actor._generator_proc) |
| 565 | + # Fetcher procs exist only when shared-memory prefetch is enabled. |
| | + if getattr(actor, "_fetcher_procs", None) is not None: |
| | + await stop_proc_mesh(actor._fetcher_procs) |
493 | 566 |
|
494 | 567 | @endpoint |
495 | 568 | async def _test_save_model_params(self): |
@@ -573,14 +646,42 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput: |
573 | 646 | return self.worker.execute_model(schedule) |
574 | 647 |
|
575 | 648 | @endpoint |
576 | | - async def update_weights(self, version: int) -> None: |
| 649 | + async def update_weights( |
| 650 | + self, |
| 651 | + version: Optional[int] = None, |
| 652 | + *, |
| 653 | + shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None, |
| 654 | + ) -> None: |
577 | 655 | model = self.worker.model_runner.model |
| 656 | + if shared_memory_state_dict is not None: |
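| | + # Fast path: weights were already prefetched into host shared memory by the Generator. |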
| 657 | + logger.info("[GeneratorWorker] Updating weights from shared memory.") |
| 658 | + t = Tracer( |
| 659 | + "generator_worker_perf/update_weights_from_shared_memory", timer="gpu" |
| 660 | + ) |
| 661 | + t.start() |
| 662 | + loaded_weights = set() |
| 663 | + for name, param_handle in shared_memory_state_dict.items(): |
| 664 | + # Use context manager for automatic cleanup |
| 665 | + with param_handle.to_shared_tensor() as shared_tensor: |
| 666 | + param = shared_tensor.tensor |
| 667 | + loaded = model.load_weights([(name, param)]) |
| 668 | + del param |
| 669 | + loaded_weights.update(loaded) |
| 670 | + logger.info(f"[GeneratorWorker] Updated {len(loaded_weights)} parameters") |
| 671 | + t.stop() |
| 672 | + return |
| 673 | + # normal update_weights without shared memory prefetching |
| 674 | + if version is None: |
| 675 | + raise ValueError( |
| 676 | + "version must be provided if not using shared_memory_state_dict" |
| 677 | + ) |
| 678 | + logger.info("[GeneratorWorker] Updating weights from torchstore.") |
578 | 679 | prefix = get_param_prefix(version) |
579 | 680 | matching_keys = await ts.keys(prefix) |
580 | 681 | dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) |
581 | 682 | use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys |
582 | 683 | loaded_weights = set() |
583 | | - t = Tracer("worker_perf/update_weights", timer="gpu") |
| 684 | + t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu") |
584 | 685 | t.start() |
585 | 686 |
|
586 | 687 | if use_dcp_for_weight_sync: |
@@ -622,3 +723,27 @@ async def _test_validate_model_params(self, validate_fn): |
622 | 723 | return validate_fn( |
623 | 724 | self._test_prev_params, self.worker.model_runner.model, logger |
624 | 725 | ) |
| 726 | + |
| 727 | + |
| 728 | +class _WeightFetcher(ForgeActor): |
| 729 | + """Fetches weights from torchstore and loads them into shared memory. |
| 730 | + This must be colocated with the GeneratorWorker so the shared-memory segments it creates are reachable by the workers.""" |
| 731 | + |
| 732 | + @endpoint |
| 733 | + async def fetch( |
| 734 | + self, |
| 735 | + *, |
| 736 | + version: int, |
| 737 | + param_names: list[str], |
| 738 | + ) -> dict[str, SharedTensorHandle]: |
| 739 | + """Fetch weights from torchstore and load them into shared memory.""" |
| 740 | + sd = {} |
| 741 | + for name in param_names: |
| 742 | + param_key = get_param_key(version, name) |
| 743 | + param = await ts.get(param_key) |
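| | + # Copy the fetched tensor into a shared-memory segment; only the lightweight handle is returned to the caller. |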
| 744 | + # Use context manager to ensure cleanup after getting handle |
| 745 | + with SharedTensor(tensor=param) as shared_tensor: |
| 746 | + handle = shared_tensor.get_handle() |
| 747 | + sd[name] = handle |
| 748 | + del param # Explicitly free the tensor after copying to shared memory |
| 749 | + return sd |