Skip to content

Commit 4a63c05

Browse files
committed
as title
1 parent bcd86f0 commit 4a63c05

File tree

2 files changed

+24
-23
lines changed

2 files changed

+24
-23
lines changed

src/forge/actors/generator.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -492,16 +492,17 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
492492
await stop_proc_mesh(actor._generator_proc)
493493

494494
@endpoint
495-
async def _test_save_model_params(self):
496-
"""Save model parameters before weight update, used for tesing purposes only."""
497-
logger.info("[Generator] save model parameters for testing.")
498-
await self.worker._test_save_model_params.call()
495+
496+
async def save_model_params(self):
497+
"""Used for debugging purpose. Save model parameters before weight update."""
498+
logger.info("[Generator] save model parameters for debugging.")
499+
await self.workerr.save_model_params.call()
499500

500501
@endpoint
501-
async def _test_validate_model_params(self, validate_fn):
502-
"""Validate updated model params using validate_fn."""
503-
logger.info("[Generator] start validating model parameters.")
504-
return await self.worker._test_validate_model_params.call(validate_fn)
502+
async def validate_model_params(self, validate_fn):
503+
"""Used for debugging purpose. Validate saved params using validate_fn."""
504+
logger.info("[Generator] start validating saved model parameters.")
505+
return await self.worker.validate_model_params.call(validate_fn)
505506

506507

507508
@dataclass
@@ -514,8 +515,8 @@ class GeneratorWorker(ForgeActor):
514515
"""
515516

516517
vllm_config: VllmConfig
517-
# TODO: Remove below param
518-
_test_prev_params = {}
518+
# Used for debugging purposes only
519+
debug_saved_params = {}
519520

520521
@endpoint
521522
async def setup(self):
@@ -605,20 +606,20 @@ async def update_weights(self, version: int) -> None:
605606
t.stop()
606607

607608
@endpoint
608-
async def _test_save_model_params(self):
609-
"""Save model parameters before weight update, used for tesing purposes only."""
610-
logger.info("[GeneratorWorker] save model parameters for testing.")
609+
async def save_model_params(self):
610+
"""Used for debugging purposes. Save model parameters before weight update."""
611+
logger.info("[GeneratorWorker] save model parameters for debugging.")
611612
for name, param in self.worker.model_runner.model.named_parameters():
612-
self._test_prev_params[name] = param.detach().cpu()
613+
self.debug_saved_params[name] = param.detach().cpu()
613614
logger.info(
614615
"[GeneratorWorker] finished saving model parameters, len = %d",
615-
len(self._test_prev_params),
616+
len(self.debug_saved_params),
616617
)
617618

618619
@endpoint
619-
async def _test_validate_model_params(self, validate_fn):
620-
"""Validate updated model params using validate_fn."""
621-
logger.info("[GeneratorWorker] start validating model parameters.")
620+
async def validate_model_params(self, validate_fn):
621+
"""Used for debugging purposes. Validate saved params using validate_fn."""
622+
logger.info("[GeneratorWorker] start validating saved model parameters.")
622623
return validate_fn(
623-
self._test_prev_params, self.worker.model_runner.model, logger
624+
self.debug_saved_params, self.worker.model_runner.model, logger
624625
)

tests/integration_tests/test_policy_update.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,18 +254,18 @@ async def test_sanity_check(self, _setup_and_teardown):
254254
# Setting everything to zero
255255
await rl_trainer.zero_out_model_states.call()
256256
await rl_trainer.push_weights.call(policy_version=v1)
257-
await policy._test_save_model_params.fanout()
257+
await policy.save_model_params.fanout()
258258

259259
# Sanity check that before update all the tests pass
260-
all_errs = await policy._test_validate_model_params.fanout(
260+
all_errs = await policy.validate_model_params.fanout(
261261
_test_validate_params_unchanged
262262
)
263263
for errs in all_errs:
264264
for _, e in errs.items():
265265
assert not e, f"Validation failed with exception: {e}"
266266

267267
await policy.update_weights.fanout(version=v1)
268-
all_errs = await policy._test_validate_model_params.fanout(
268+
all_errs = await policy.validate_model_params.fanout(
269269
_test_validate_params_all_zeros
270270
)
271271
for errs in all_errs:
@@ -274,7 +274,7 @@ async def test_sanity_check(self, _setup_and_teardown):
274274

275275
# Reloading v0, getting back original weights
276276
await policy.update_weights.fanout(version=v0)
277-
all_errs = await policy._test_validate_model_params.fanout(
277+
all_errs = await policy.validate_model_params.fanout(
278278
_test_validate_params_unchanged
279279
)
280280
for errs in all_errs:

0 commit comments

Comments
 (0)