Skip to content

Commit 2d14c40

Browse files
committed
cleanup, guard with flag
1 parent fab7cd4 commit 2d14c40

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

src/forge/actors/generator.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ class Generator(GeneratorInterface):
103103
TORCHSTORE_USE_RDMA.get_value() == 0
104104
) # torchstore currently only accepts 0 or 1
105105
# Remaining variables are initialized in self.setup()
106+
prefetch_weights: bool = False
107+
n_fetcher_procs: int = 8
106108
lora_request: LoRARequest | None = None
107109
tokenization_kwargs: dict = field(default_factory=dict)
108110
generator_worker: GeneratorWorker | None = None
109-
weight_fetchers: _WeightFetcher = None
111+
weight_fetchers: _WeightFetcher | None = None
110112

111113
def __post_init__(self):
112114
super().__init__()
@@ -213,11 +215,6 @@ async def setup(self):
213215
self.request_lock = asyncio.Condition() # Guard for accepting_requests
214216
self.update_lock = asyncio.Condition() # Guard for updating requests
215217

216-
# Shared memory allocated for weight updates
217-
self.cached_state_dict_allocs: queue.Queue[
218-
dict[str, SharedTensorHandle]
219-
] = queue.Queue(maxsize=2)
220-
221218
vllm_config: VllmConfig = self.engine_args.create_engine_config(
222219
UsageContext.LLM_CLASS
223220
)
@@ -253,7 +250,9 @@ async def setup(self):
253250
log_stats=None,
254251
)
255252
self._start_processing()
256-
fetcher_procs = this_host().spawn_procs(per_host={"procs": 1})
253+
fetcher_procs = this_host().spawn_procs(
254+
per_host={"procs": self.n_fetcher_procs}
255+
)
257256
self._fetcher_procs = fetcher_procs
258257
self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
259258

@@ -465,8 +464,12 @@ async def update_weights(self, version: int) -> None:
465464
>>> await trainer.push_weights()
466465
>>> generator.update_weights(version)
467466
"""
468-
logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
469-
fetch_fut = asyncio.create_task(self._fetch_weights(version))
467+
# Prefetch only if we are using RDMA
468+
if self.prefetch_weights and not self.use_dcp:
469+
logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
470+
fetch_fut = asyncio.create_task(self._fetch_weights(version))
471+
else:
472+
fetch_fut = None
470473
# Serialize updates (only one update at a time)
471474
async with self.update_lock:
472475
# Grab the lock to stop accepting requests and wait on pending requests
@@ -498,10 +501,8 @@ async def update_weights(self, version: int) -> None:
498501
)
499502

500503
logger.debug(f"Starting weight update on {self.__class__.__name__}")
501-
if not self.use_dcp:
502-
# TODO: currently the alloc in ts.get will block the event loop unfortunately
503-
# potentially we need to change torchstore
504-
# We have to do this because Monarch future is not directly compatible with asyncio
504+
505+
if fetch_fut is not None:
505506
t = Tracer("generator_perf/waiting_for_fetch_weights")
506507
t.start()
507508
fetched_weights = await fetch_fut
@@ -700,16 +701,17 @@ async def update_weights(
700701
logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
701702
t.stop()
702703
return
704+
# normal update_weights without shared memory prefetching
703705
if version is None:
704706
raise ValueError(
705707
"version must be provided if not using shared_memory_state_dict"
706708
)
707-
# If shared memory is not provided, we assume we are using DCP
709+
logger.info("[PolicyWorker] update weights from torchstore.")
708710
prefix = get_param_prefix(version)
709711
matching_keys = await ts.keys(prefix)
710712
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
711713
loaded_weights = set()
712-
t = Tracer("generator_worker_perf/update_weights", timer="gpu")
714+
t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
713715
t.start()
714716
# Entire state dict is stored in a single DCP handle
715717
if dcp_whole_state_dict_key in matching_keys:
@@ -721,7 +723,11 @@ async def update_weights(
721723
del param
722724
loaded_weights.update(loaded)
723725
else:
724-
raise RuntimeError("No DCP handle found for the given version")
726+
for key in matching_keys:
727+
param = await ts.get(key)
728+
loaded = model.load_weights([(key, param)])
729+
del param
730+
loaded_weights.update(loaded)
725731
t.stop()
726732

727733
@endpoint

0 commit comments

Comments
 (0)