Skip to content

Commit 090b6ec

Browse files
committed
spawn in setup
1 parent 002d68b commit 090b6ec

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/forge/actors/generator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import torch
2020
import torchstore as ts
21-
from monarch.actor import current_rank, endpoint, ProcMesh
21+
from monarch.actor import current_rank, endpoint, ProcMesh, this_host
2222

2323
from vllm.config import VllmConfig
2424

@@ -174,8 +174,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
174174
"vllm_worker", GeneratorWorker, vllm_config=vllm_config, use_dcp=use_dcp
175175
)
176176

177-
weight_fetchers = worker_procs.spawn("weight_fetcher", _WeightFetcher)
178-
179177
if isinstance(sampling_params, Mapping):
180178
sampling_params = SamplingParams.from_optional(**sampling_params)
181179
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
@@ -256,6 +254,9 @@ async def setup(self):
256254
log_stats=None,
257255
)
258256
self._start_processing()
257+
fetcher_procs = this_host().spawn_procs(per_host={"procs": 8})
258+
self._fetcher_procs = fetcher_procs
259+
self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
259260

260261
def _start_processing(self):
261262
if self._run_task is None or self._run_task.done():
@@ -585,6 +586,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
585586
await actor._cleanup_shared_memory.call()
586587
await stop_proc_mesh(actor._worker_procs)
587588
await stop_proc_mesh(actor._generator_proc)
589+
await stop_proc_mesh(actor._fetcher_procs)
588590

589591
@endpoint
590592
async def _test_save_model_params(self):

0 commit comments

Comments
 (0)