@@ -103,10 +103,12 @@ class Generator(GeneratorInterface):
         TORCHSTORE_USE_RDMA.get_value() == 0
     )  # torchstore currently only accepts 0 or 1
     # Remaining variables are initialized in self.setup()
+    prefetch_weights: bool = False
+    n_fetcher_procs: int = 8
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     generator_worker: GeneratorWorker | None = None
-    weight_fetchers: _WeightFetcher = None
+    weight_fetchers: _WeightFetcher | None = None
 
     def __post_init__(self):
         super().__init__()
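The two new dataclass fields above gate the prefetch path introduced later in this commit. A minimal usage sketch, assuming a typical construction site (the `engine_args` value here is a placeholder, not from this diff):

```python
# Sketch only: prefetch_weights and n_fetcher_procs come from this diff;
# the other arguments are hypothetical placeholders.
generator = Generator(
    engine_args=engine_args,   # placeholder vLLM EngineArgs
    prefetch_weights=True,     # opt in to shared-memory weight prefetching
    n_fetcher_procs=4,         # number of fetcher processes per host
)
```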
@@ -213,11 +215,6 @@ async def setup(self):
         self.request_lock = asyncio.Condition()  # Guard for accepting_requests
         self.update_lock = asyncio.Condition()  # Guard for updating requests
 
-        # Shared memory allocated for weight updates
-        self.cached_state_dict_allocs: queue.Queue[
-            dict[str, SharedTensorHandle]
-        ] = queue.Queue(maxsize=2)
-
         vllm_config: VllmConfig = self.engine_args.create_engine_config(
             UsageContext.LLM_CLASS
         )
@@ -253,7 +250,9 @@ async def setup(self):
             log_stats=None,
         )
         self._start_processing()
-        fetcher_procs = this_host().spawn_procs(per_host={"procs": 1})
+        fetcher_procs = this_host().spawn_procs(
+            per_host={"procs": self.n_fetcher_procs}
+        )
         self._fetcher_procs = fetcher_procs
         self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)
@@ -465,8 +464,12 @@ async def update_weights(self, version: int) -> None:
         >>> await trainer.push_weights()
         >>> generator.update_weights(version)
         """
-        logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
-        fetch_fut = asyncio.create_task(self._fetch_weights(version))
+        # Prefetch to shared memory only when enabled and on the RDMA path
+        if self.prefetch_weights and not self.use_dcp:
+            logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
+            fetch_fut = asyncio.create_task(self._fetch_weights(version))
+        else:
+            fetch_fut = None
         # Serialize updates (only one update at a time)
         async with self.update_lock:
             # Grab the lock to stop accepting requests and wait on pending requests
@@ -498,10 +501,8 @@ async def update_weights(self, version: int) -> None:
             )
 
             logger.debug(f"Starting weight update on {self.__class__.__name__}")
-            if not self.use_dcp:
-                # TODO: currently the alloc in ts.get will block the event loop unfortunately
-                # potentially we need to change torchstore
-                # We have to do this because Monarch future is not directly compatible with asyncio
+
+            if fetch_fut is not None:
                 t = Tracer("generator_perf/waiting_for_fetch_weights")
                 t.start()
                 fetched_weights = await fetch_fut
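Once the prefetch future resolves, the fetched handles are handed to the worker-side `update_weights`; otherwise the worker fetches by version itself. A hedged sketch of the dispatch that follows (the `.call(...)` fan-out style is assumed; the `shared_memory_state_dict` parameter name matches the worker's error message below):

```python
if fetch_fut is not None:
    # RDMA path: workers load from the prefetched shared-memory handles
    await self.generator_worker.update_weights.call(
        shared_memory_state_dict=fetched_weights
    )
else:
    # DCP/torchstore path: workers fetch the weights for `version` themselves
    await self.generator_worker.update_weights.call(version=version)
```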
@@ -700,16 +701,17 @@ async def update_weights(
             logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
             t.stop()
             return
+        # Normal update_weights path, without shared-memory prefetching
         if version is None:
             raise ValueError(
                 "version must be provided if not using shared_memory_state_dict"
             )
-        # If shared memory is not provided, we assume we are using DCP
+        logger.info("[PolicyWorker] Updating weights from torchstore.")
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
         loaded_weights = set()
-        t = Tracer("generator_worker_perf/update_weights", timer="gpu")
+        t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
         t.start()
         # Entire state dict is stored in a single DCP handle
         if dcp_whole_state_dict_key in matching_keys:
@@ -721,7 +723,11 @@ async def update_weights(
                 del param
                 loaded_weights.update(loaded)
         else:
-            raise RuntimeError("No DCP handle found for the given version")
+            for key in matching_keys:
+                param = await ts.get(key)
+                loaded = model.load_weights([(key, param)])
+                del param
+                loaded_weights.update(loaded)
         t.stop()
 
     @endpoint
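The new `else` branch replaces the old hard failure with a per-parameter fallback: when no whole-state-dict DCP handle exists, each tensor is fetched from torchstore and loaded individually, with `del param` releasing host memory between iterations. The same loop factored into a standalone helper (hypothetical name), assuming `ts.get` returns a tensor and `model.load_weights` accepts `(name, tensor)` pairs as in the hunk above:

```python
async def load_params_individually(model, version: int) -> set[str]:
    # Sketch of the per-key fallback path from the diff above
    loaded: set[str] = set()
    for key in await ts.keys(get_param_prefix(version)):
        param = await ts.get(key)
        loaded.update(model.load_weights([(key, param)]))
        del param  # free the fetched tensor before the next iteration
    return loaded
```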