From 4a63c0576faabc20f7850d3314339c0ec432e9bd Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 15 Oct 2025 12:31:31 -0400 Subject: [PATCH 1/3] as title --- src/forge/actors/generator.py | 39 ++++++++++--------- tests/integration_tests/test_policy_update.py | 8 ++-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index e04bed5a8..89b710a3c 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -492,16 +492,17 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] await stop_proc_mesh(actor._generator_proc) @endpoint - async def _test_save_model_params(self): - """Save model parameters before weight update, used for tesing purposes only.""" - logger.info("[Generator] save model parameters for testing.") - await self.worker._test_save_model_params.call() + + async def save_model_params(self): + """Used for debugging purpose. Save model parameters before weight update.""" + logger.info("[Generator] save model parameters for debugging.") + await self.workerr.save_model_params.call() @endpoint - async def _test_validate_model_params(self, validate_fn): - """Validate updated model params using validate_fn.""" - logger.info("[Generator] start validating model parameters.") - return await self.worker._test_validate_model_params.call(validate_fn) + async def validate_model_params(self, validate_fn): + """Used for debugging purpose. Validate saved params using validate_fn.""" + logger.info("[Generator] start validating saved model parameters.") + return await self.worker.validate_model_params.call(validate_fn) @dataclass @@ -514,8 +515,8 @@ class GeneratorWorker(ForgeActor): """ vllm_config: VllmConfig - # TODO: Remove below param - _test_prev_params = {} + # Used for debugging purposes only + debug_saved_params = {} @endpoint async def setup(self): @@ -605,20 +606,20 @@ async def update_weights(self, version: int) -> None: t.stop() @endpoint - async def _test_save_model_params(self): - """Save model parameters before weight update, used for tesing purposes only.""" - logger.info("[GeneratorWorker] save model parameters for testing.") + async def save_model_params(self): + """Used for debugging purposes. Save model parameters before weight update.""" + logger.info("[GeneratorWorker] save model parameters for debugging.") for name, param in self.worker.model_runner.model.named_parameters(): - self._test_prev_params[name] = param.detach().cpu() + self.debug_saved_params[name] = param.detach().cpu() logger.info( "[GeneratorWorker] finished saving model parameters, len = %d", - len(self._test_prev_params), + len(self.debug_saved_params), ) @endpoint - async def _test_validate_model_params(self, validate_fn): - """Validate updated model params using validate_fn.""" - logger.info("[GeneratorWorker] start validating model parameters.") + async def validate_model_params(self, validate_fn): + """Used for debugging purposes. Validate saved params using validate_fn.""" + logger.info("[GeneratorWorker] start validating saved model parameters.") return validate_fn( - self._test_prev_params, self.worker.model_runner.model, logger + self.debug_saved_params, self.worker.model_runner.model, logger ) diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 10b2852b7..0b99e75a2 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -254,10 +254,10 @@ async def test_sanity_check(self, _setup_and_teardown): # Setting everything to zero await rl_trainer.zero_out_model_states.call() await rl_trainer.push_weights.call(policy_version=v1) - await policy._test_save_model_params.fanout() + await policy.save_model_params.fanout() # Sanity check that before update all the tests pass - all_errs = await policy._test_validate_model_params.fanout( + all_errs = await policy.validate_model_params.fanout( _test_validate_params_unchanged ) for errs in all_errs: @@ -265,7 +265,7 @@ async def test_sanity_check(self, _setup_and_teardown): assert not e, f"Validation failed with exception: {e}" await policy.update_weights.fanout(version=v1) - all_errs = await policy._test_validate_model_params.fanout( + all_errs = await policy.validate_model_params.fanout( _test_validate_params_all_zeros ) for errs in all_errs: @@ -274,7 +274,7 @@ async def test_sanity_check(self, _setup_and_teardown): # Reloading v0, getting back original weights await policy.update_weights.fanout(version=v0) - all_errs = await policy._test_validate_model_params.fanout( + all_errs = await policy.validate_model_params.fanout( _test_validate_params_unchanged ) for errs in all_errs: From 47695ef507f56ac6e25a5d028f89f79c8d25dc1c Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Wed, 15 Oct 2025 15:12:07 -0400 Subject: [PATCH 2/3] typo --- src/forge/actors/generator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 89b710a3c..2c9722811 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -492,11 +492,10 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] await stop_proc_mesh(actor._generator_proc) @endpoint - async def save_model_params(self): """Used for debugging purpose. Save model parameters before weight update.""" logger.info("[Generator] save model parameters for debugging.") - await self.workerr.save_model_params.call() + await self.worker.save_model_params.call() @endpoint async def validate_model_params(self, validate_fn): From 39695a4ab3fe1b93d0677c8510108c6ceb67c727 Mon Sep 17 00:00:00 2001 From: "Jiyue (Jennifer) Wang" Date: Thu, 16 Oct 2025 12:07:29 -0400 Subject: [PATCH 3/3] address comments --- src/forge/actors/generator.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 2c9722811..0dc385cc0 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -494,13 +494,11 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] @endpoint async def save_model_params(self): """Used for debugging purpose. Save model parameters before weight update.""" - logger.info("[Generator] save model parameters for debugging.") await self.worker.save_model_params.call() @endpoint async def validate_model_params(self, validate_fn): """Used for debugging purpose. Validate saved params using validate_fn.""" - logger.info("[Generator] start validating saved model parameters.") return await self.worker.validate_model_params.call(validate_fn) @@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor): """ vllm_config: VllmConfig - # Used for debugging purposes only - debug_saved_params = {} @endpoint async def setup(self): @@ -607,18 +603,17 @@ async def update_weights(self, version: int) -> None: @endpoint async def save_model_params(self): """Used for debugging purposes. Save model parameters before weight update.""" - logger.info("[GeneratorWorker] save model parameters for debugging.") + self._debug_saved_params = {} for name, param in self.worker.model_runner.model.named_parameters(): - self.debug_saved_params[name] = param.detach().cpu() + self._debug_saved_params[name] = param.detach().cpu() logger.info( "[GeneratorWorker] finished saving model parameters, len = %d", - len(self.debug_saved_params), + len(self._debug_saved_params), ) @endpoint async def validate_model_params(self, validate_fn): """Used for debugging purposes. Validate saved params using validate_fn.""" - logger.info("[GeneratorWorker] start validating saved model parameters.") return validate_fn( - self.debug_saved_params, self.worker.model_runner.model, logger + self._debug_saved_params, self.worker.model_runner.model, logger )