@@ -492,16 +492,17 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
492492 await stop_proc_mesh (actor ._generator_proc )
493493
494494 @endpoint
495- async def _test_save_model_params (self ):
496- """Save model parameters before weight update, used for tesing purposes only."""
497- logger .info ("[Generator] save model parameters for testing." )
498- await self .worker ._test_save_model_params .call ()
495+
496+ async def save_model_params (self ):
497+ """Used for debugging purpose. Save model parameters before weight update."""
498+ logger .info ("[Generator] save model parameters for debugging." )
499+ await self .workerr .save_model_params .call ()
499500
500501 @endpoint
501- async def _test_validate_model_params (self , validate_fn ):
502- """Validate updated model params using validate_fn."""
503- logger .info ("[Generator] start validating model parameters." )
504- return await self .worker ._test_validate_model_params .call (validate_fn )
502+ async def validate_model_params (self , validate_fn ):
503+ """Used for debugging purpose. Validate saved params using validate_fn."""
504+ logger .info ("[Generator] start validating saved model parameters." )
505+ return await self .worker .validate_model_params .call (validate_fn )
505506
506507
507508@dataclass
@@ -514,8 +515,8 @@ class GeneratorWorker(ForgeActor):
514515 """
515516
516517 vllm_config : VllmConfig
517- # TODO: Remove below param
518- _test_prev_params = {}
518+ # Used for debugging purposes only
519+ debug_saved_params = {}
519520
520521 @endpoint
521522 async def setup (self ):
@@ -605,20 +606,20 @@ async def update_weights(self, version: int) -> None:
605606 t .stop ()
606607
607608 @endpoint
608- async def _test_save_model_params (self ):
609- """Save model parameters before weight update, used for tesing purposes only ."""
610- logger .info ("[GeneratorWorker] save model parameters for testing ." )
609+ async def save_model_params (self ):
610+ """Used for debugging purposes. Save model parameters before weight update."""
611+ logger .info ("[GeneratorWorker] save model parameters for debugging ." )
611612 for name , param in self .worker .model_runner .model .named_parameters ():
612- self ._test_prev_params [name ] = param .detach ().cpu ()
613+ self .debug_saved_params [name ] = param .detach ().cpu ()
613614 logger .info (
614615 "[GeneratorWorker] finished saving model parameters, len = %d" ,
615- len (self ._test_prev_params ),
616+ len (self .debug_saved_params ),
616617 )
617618
618619 @endpoint
619- async def _test_validate_model_params (self , validate_fn ):
620- """Validate updated model params using validate_fn."""
621- logger .info ("[GeneratorWorker] start validating model parameters." )
620+ async def validate_model_params (self , validate_fn ):
621+ """Used for debugging purposes. Validate saved params using validate_fn."""
622+ logger .info ("[GeneratorWorker] start validating saved model parameters." )
622623 return validate_fn (
623- self ._test_prev_params , self .worker .model_runner .model , logger
624+ self .debug_saved_params , self .worker .model_runner .model , logger
624625 )
0 commit comments