@@ -492,16 +492,14 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
492492 await stop_proc_mesh (actor ._generator_proc )
493493
494494 @endpoint
495- async def _test_save_model_params (self ):
496- """Save model parameters before weight update, used for tesing purposes only."""
497- logger .info ("[Generator] save model parameters for testing." )
498- await self .worker ._test_save_model_params .call ()
495+ async def save_model_params (self ):
496+ """Used for debugging purpose. Save model parameters before weight update."""
497+ await self .worker .save_model_params .call ()
499498
500499 @endpoint
501- async def _test_validate_model_params (self , validate_fn ):
502- """Validate updated model params using validate_fn."""
503- logger .info ("[Generator] start validating model parameters." )
504- return await self .worker ._test_validate_model_params .call (validate_fn )
500+ async def validate_model_params (self , validate_fn ):
501+ """Used for debugging purpose. Validate saved params using validate_fn."""
502+ return await self .worker .validate_model_params .call (validate_fn )
505503
506504
507505@dataclass
@@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor):
514512 """
515513
516514 vllm_config : VllmConfig
517- # TODO: Remove below param
518- _test_prev_params = {}
519515
520516 @endpoint
521517 async def setup (self ):
@@ -605,20 +601,19 @@ async def update_weights(self, version: int) -> None:
605601 t .stop ()
606602
607603 @endpoint
608- async def _test_save_model_params (self ):
609- """Save model parameters before weight update, used for tesing purposes only ."""
610- logger . info ( "[GeneratorWorker] save model parameters for testing." )
604+ async def save_model_params (self ):
605+ """Used for debugging purposes. Save model parameters before weight update."""
606+ self . _debug_saved_params = {}
611607 for name , param in self .worker .model_runner .model .named_parameters ():
612- self ._test_prev_params [name ] = param .detach ().cpu ()
608+ self ._debug_saved_params [name ] = param .detach ().cpu ()
613609 logger .info (
614610 "[GeneratorWorker] finished saving model parameters, len = %d" ,
615- len (self ._test_prev_params ),
611+ len (self ._debug_saved_params ),
616612 )
617613
618614 @endpoint
619- async def _test_validate_model_params (self , validate_fn ):
620- """Validate updated model params using validate_fn."""
621- logger .info ("[GeneratorWorker] start validating model parameters." )
615+ async def validate_model_params (self , validate_fn ):
616+ """Used for debugging purposes. Validate saved params using validate_fn."""
622617 return validate_fn (
623- self ._test_prev_params , self .worker .model_runner .model , logger
618+ self ._debug_saved_params , self .worker .model_runner .model , logger
624619 )
0 commit comments