|
18 | 18 |
|
19 | 19 | import torch |
20 | 20 | import torchstore as ts |
21 | | -from monarch.actor import current_rank, endpoint, ProcMesh |
| 21 | +from monarch.actor import current_rank, endpoint, ProcMesh, this_host |
22 | 22 |
|
23 | 23 | from vllm.config import VllmConfig |
24 | 24 |
|
@@ -174,8 +174,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] |
174 | 174 | "vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp |
175 | 175 | ) |
176 | 176 |
|
177 | | - weight_fetchers = worker_procs.spawn("weight_fetcher", _WeightFetcher) |
178 | | - |
179 | 177 | if isinstance(sampling_params, Mapping): |
180 | 178 | sampling_params = SamplingParams.from_optional(**sampling_params) |
181 | 179 | sampling_params.output_kind = RequestOutputKind.FINAL_ONLY |
@@ -256,6 +254,9 @@ async def setup(self): |
256 | 254 | log_stats=None, |
257 | 255 | ) |
258 | 256 | self._start_processing() |
| 257 | + fetcher_procs = this_host().spawn_procs(per_host={"procs": 8}) |
| 258 | + self._fetcher_procs = fetcher_procs |
| 259 | + self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher) |
259 | 260 |
|
260 | 261 | def _start_processing(self): |
261 | 262 | if self._run_task is None or self._run_task.done(): |
@@ -585,6 +586,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] |
585 | 586 | await actor._cleanup_shared_memory.call() |
586 | 587 | await stop_proc_mesh(actor._worker_procs) |
587 | 588 | await stop_proc_mesh(actor._generator_proc) |
| 589 | + await stop_proc_mesh(actor._fetcher_procs) |
588 | 590 |
|
589 | 591 | @endpoint |
590 | 592 | async def _test_save_model_params(self): |
|
0 commit comments