diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index e04bed5a8..0dc385cc0 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -492,16 +492,14 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] await stop_proc_mesh(actor._generator_proc) @endpoint - async def _test_save_model_params(self): - """Save model parameters before weight update, used for tesing purposes only.""" - logger.info("[Generator] save model parameters for testing.") - await self.worker._test_save_model_params.call() + async def save_model_params(self): + """Used for debugging purposes. Save model parameters before weight update.""" + await self.worker.save_model_params.call() @endpoint - async def _test_validate_model_params(self, validate_fn): - """Validate updated model params using validate_fn.""" - logger.info("[Generator] start validating model parameters.") - return await self.worker._test_validate_model_params.call(validate_fn) + async def validate_model_params(self, validate_fn): + """Used for debugging purposes. Validate saved params using validate_fn.""" + return await self.worker.validate_model_params.call(validate_fn) @dataclass @@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor): """ vllm_config: VllmConfig - # TODO: Remove below param - _test_prev_params = {} @endpoint async def setup(self): @@ -605,20 +601,19 @@ async def update_weights(self, version: int) -> None: t.stop() @endpoint - async def _test_save_model_params(self): - """Save model parameters before weight update, used for tesing purposes only.""" - logger.info("[GeneratorWorker] save model parameters for testing.") + async def save_model_params(self): + """Used for debugging purposes. 
Save model parameters before weight update.""" + self._debug_saved_params = {} for name, param in self.worker.model_runner.model.named_parameters(): - self._test_prev_params[name] = param.detach().cpu() + self._debug_saved_params[name] = param.detach().cpu() logger.info( "[GeneratorWorker] finished saving model parameters, len = %d", - len(self._test_prev_params), + len(self._debug_saved_params), ) @endpoint - async def _test_validate_model_params(self, validate_fn): - """Validate updated model params using validate_fn.""" - logger.info("[GeneratorWorker] start validating model parameters.") + async def validate_model_params(self, validate_fn): + """Used for debugging purposes. Validate saved params using validate_fn.""" return validate_fn( - self._test_prev_params, self.worker.model_runner.model, logger + self._debug_saved_params, self.worker.model_runner.model, logger ) diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 10b2852b7..0b99e75a2 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -254,10 +254,10 @@ async def test_sanity_check(self, _setup_and_teardown): # Setting everything to zero await rl_trainer.zero_out_model_states.call() await rl_trainer.push_weights.call(policy_version=v1) - await policy._test_save_model_params.fanout() + await policy.save_model_params.fanout() # Sanity check that before update all the tests pass - all_errs = await policy._test_validate_model_params.fanout( + all_errs = await policy.validate_model_params.fanout( _test_validate_params_unchanged ) for errs in all_errs: @@ -265,7 +265,7 @@ async def test_sanity_check(self, _setup_and_teardown): assert not e, f"Validation failed with exception: {e}" await policy.update_weights.fanout(version=v1) - all_errs = await policy._test_validate_model_params.fanout( + all_errs = await policy.validate_model_params.fanout( _test_validate_params_all_zeros ) for errs in 
all_errs: @@ -274,7 +274,7 @@ async def test_sanity_check(self, _setup_and_teardown): # Reloading v0, getting back original weights await policy.update_weights.fanout(version=v0) - all_errs = await policy._test_validate_model_params.fanout( + all_errs = await policy.validate_model_params.fanout( _test_validate_params_unchanged ) for errs in all_errs: