1717import torch
1818import torchstore as ts
1919from monarch .actor import current_rank , endpoint , ProcMesh
20- from torchstore .state_dict_utils import DELIM
2120from vllm .config import VllmConfig
2221
2322from vllm .engine .arg_utils import EngineArgs
4039from vllm .v1 .structured_output import StructuredOutputManager
4140from vllm .worker .worker_base import WorkerWrapperBase
4241
43- from forge .controller import ForgeActor , get_proc_mesh , stop_proc_mesh
42+ from forge .actors .torchstore_utils import (
43+ extract_param_name ,
44+ get_param_key ,
45+ get_param_prefix ,
46+ )
4447
48+ from forge .controller import ForgeActor , get_proc_mesh , stop_proc_mesh
4549from forge .data .sharding import VLLMSharding
4650from forge .interfaces import Policy as PolicyInterface
4751from forge .types import ProcessConfig
52+ from forge .util .async_utils import make_sync_generator
4853
4954
5055@dataclass
@@ -364,16 +369,16 @@ async def run(self):
364369 fut .set_result (request_output )
365370
366371 @endpoint
367- async def update_weights (self ):
372+ async def update_weights (self , policy_version : int ):
368373 # TODO: If generating long sequences, this might be long and will block policy weight updates
369374 curr_requests = [fut for _ , fut in self .requests .values ()]
370375 if curr_requests :
371376 self .logger .debug (f"Waiting for { len (curr_requests )} pending requests" )
372377 await asyncio .gather (* curr_requests )
373378
374379 self .logger .debug (f"Starting weight update on { self .__class__ .__name__ } " )
375- await self .policy_worker .update .call (version = self . weights_version )
376- self .weights_version += 1
380+ await self .policy_worker .update .call (version = policy_version )
381+ self .weights_version = policy_version
377382 self .logger .info (f"Weight update completed (now v{ self .weights_version } )" )
378383
379384 @endpoint
@@ -395,7 +400,6 @@ async def stop(self):
395400@dataclass
396401class PolicyWorker (ForgeActor ):
397402 vllm_config : VllmConfig
398- state_dict_key : str = "model_state_dict"
399403
400404 @endpoint
401405 async def setup (self ):
@@ -407,41 +411,26 @@ async def setup(self):
407411 async def execute_model (self , schedule : SchedulerOutput ):
408412 return self .worker .execute_model (schedule )
409413
410- async def _load_tensor_parallel_state_dict (
411- self , current_state_dict : dict , version : int
412- ):
413- """
414- Load full state dict from torchstore into tensor parallel model with deterministic sharding.
415- """
416- sharding = VLLMSharding (
417- self .vllm_config .parallel_config .tensor_parallel_size , self .rank
418- )
419-
420- for param_name in current_state_dict .keys ():
421- current_tensor = current_state_dict [param_name ]
422-
423- # Load the full tensor from torchstore
424- # TODO: only get the part of the tensor that is needed
425- stored_tensor = await ts .get (
426- f"{ self .state_dict_key } { DELIM } { version } { DELIM } { param_name } "
427- )
428- sharding .load_from_source_to_target (
429- param_name ,
430- stored_tensor ,
431- current_tensor ,
432- )
433-
434414 @endpoint
435415 async def update (self , version : int ):
436416 """Update model weights by reading state dict from torchstore"""
437- key = f"{ self .state_dict_key } { DELIM } { version } "
438417 model = self .worker .model_runner .model
439- current_state_dict = model .state_dict ()
440- start = time .time ()
441- await self ._load_tensor_parallel_state_dict (current_state_dict , version )
442- self .logger .debug (
443- f"Loaded state dict from { key } in { time .time () - start } seconds"
444- )
418+ prefix = get_param_prefix (version )
419+ self .logger .debug (f"{ prefix = } " )
420+ matching_keys = await ts .keys (prefix )
421+ self .logger .debug (f"{ matching_keys = } " )
422+ # TODO: find a way to save the original huggingface parameter names.
423+ hf_names = [extract_param_name (key ) for key in matching_keys ]
424+ self .logger .debug (f"{ hf_names = } " )
425+ loaded_weights = set ()
426+ # We can't pass a generator since vllm load_weights is not async.
427+ # Instead, we just call load_weights with one parameter at a time.
428+ for name in hf_names :
429+ param = await ts .get (get_param_key (version , name ))
430+ loaded = model .load_weights ([(name , param )])
431+ del param
432+ loaded_weights .update (loaded )
433+ self .logger .info (f"Updated { len (loaded_weights )} parameters" )
445434
446435 @endpoint
447436 async def setup_kv_cache (self ):
0 commit comments