4646 get_param_key ,
4747 get_param_prefix ,
4848 load_tensor_from_dcp ,
49+ rdma_available ,
4950)
5051
5152from forge .controller import (
5657)
5758from forge .data_models .completion import Completion
5859from forge .data_models .prompt import to_prompt
59- from forge .env import TORCHSTORE_USE_RDMA
6060from forge .observability .metrics import record_metric , Reduce
6161from forge .observability .perf_tracker import Tracer
6262from forge .types import ProcessConfig
@@ -112,7 +112,7 @@ def __post_init__(self):
112112 self .sampling_params .output_kind = RequestOutputKind .FINAL_ONLY
113113
114114 if self .use_dcp_for_weight_sync is None :
115- self .use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA . get_value ()
115+ self .use_dcp_for_weight_sync = not rdma_available ()
116116 logger .debug (f"{ self .use_dcp_for_weight_sync = } " )
117117
118118 @endpoint
@@ -492,14 +492,16 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
492492 await stop_proc_mesh (actor ._generator_proc )
493493
494494 @endpoint
495- async def save_model_params (self ):
496- """Used for debugging purpose. Save model parameters before weight update."""
497- await self .worker .save_model_params .call ()
495+ async def _test_save_model_params (self ):
496+ """Save model parameters before weight update, used for tesing purposes only."""
497+ logger .info ("[Generator] save model parameters for testing." )
498+ await self .worker ._test_save_model_params .call ()
498499
499500 @endpoint
500- async def validate_model_params (self , validate_fn ):
501- """Used for debugging purpose. Validate saved params using validate_fn."""
502- return await self .worker .validate_model_params .call (validate_fn )
501+ async def _test_validate_model_params (self , validate_fn ):
502+ """Validate updated model params using validate_fn."""
503+ logger .info ("[Generator] start validating model parameters." )
504+ return await self .worker ._test_validate_model_params .call (validate_fn )
503505
504506
505507@dataclass
@@ -512,6 +514,8 @@ class GeneratorWorker(ForgeActor):
512514 """
513515
514516 vllm_config : VllmConfig
517+ # TODO: Remove below param
518+ _test_prev_params = {}
515519
516520 @endpoint
517521 async def setup (self ):
@@ -601,19 +605,20 @@ async def update_weights(self, version: int) -> None:
601605 t .stop ()
602606
603607 @endpoint
604- async def save_model_params (self ):
605- """Used for debugging purposes. Save model parameters before weight update."""
606- self . _debug_saved_params = {}
608+ async def _test_save_model_params (self ):
609+ """Save model parameters before weight update, used for tesing purposes only ."""
610+ logger . info ( "[GeneratorWorker] save model parameters for testing." )
607611 for name , param in self .worker .model_runner .model .named_parameters ():
608- self ._debug_saved_params [name ] = param .detach ().cpu ()
612+ self ._test_prev_params [name ] = param .detach ().cpu ()
609613 logger .info (
610614 "[GeneratorWorker] finished saving model parameters, len = %d" ,
611- len (self ._debug_saved_params ),
615+ len (self ._test_prev_params ),
612616 )
613617
614618 @endpoint
615- async def validate_model_params (self , validate_fn ):
616- """Used for debugging purposes. Validate saved params using validate_fn."""
619+ async def _test_validate_model_params (self , validate_fn ):
620+ """Validate updated model params using validate_fn."""
621+ logger .info ("[GeneratorWorker] start validating model parameters." )
617622 return validate_fn (
618- self ._debug_saved_params , self .worker .model_runner .model , logger
623+ self ._test_prev_params , self .worker .model_runner .model , logger
619624 )
0 commit comments