Skip to content

Commit fd33e3a

Browse files
authored
Rename _test method in generator.py (#421)
1 parent bcd86f0 commit fd33e3a

File tree

2 files changed

+18
-23
lines changed

2 files changed

+18
-23
lines changed

src/forge/actors/generator.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -492,16 +492,14 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
492492
await stop_proc_mesh(actor._generator_proc)
493493

494494
@endpoint
495-
async def _test_save_model_params(self):
496-
"""Save model parameters before weight update, used for tesing purposes only."""
497-
logger.info("[Generator] save model parameters for testing.")
498-
await self.worker._test_save_model_params.call()
495+
async def save_model_params(self):
496+
"""Used for debugging purpose. Save model parameters before weight update."""
497+
await self.worker.save_model_params.call()
499498

500499
@endpoint
501-
async def _test_validate_model_params(self, validate_fn):
502-
"""Validate updated model params using validate_fn."""
503-
logger.info("[Generator] start validating model parameters.")
504-
return await self.worker._test_validate_model_params.call(validate_fn)
500+
async def validate_model_params(self, validate_fn):
501+
"""Used for debugging purpose. Validate saved params using validate_fn."""
502+
return await self.worker.validate_model_params.call(validate_fn)
505503

506504

507505
@dataclass
@@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor):
514512
"""
515513

516514
vllm_config: VllmConfig
517-
# TODO: Remove below param
518-
_test_prev_params = {}
519515

520516
@endpoint
521517
async def setup(self):
@@ -605,20 +601,19 @@ async def update_weights(self, version: int) -> None:
605601
t.stop()
606602

607603
@endpoint
608-
async def _test_save_model_params(self):
609-
"""Save model parameters before weight update, used for tesing purposes only."""
610-
logger.info("[GeneratorWorker] save model parameters for testing.")
604+
async def save_model_params(self):
605+
"""Used for debugging purposes. Save model parameters before weight update."""
606+
self._debug_saved_params = {}
611607
for name, param in self.worker.model_runner.model.named_parameters():
612-
self._test_prev_params[name] = param.detach().cpu()
608+
self._debug_saved_params[name] = param.detach().cpu()
613609
logger.info(
614610
"[GeneratorWorker] finished saving model parameters, len = %d",
615-
len(self._test_prev_params),
611+
len(self._debug_saved_params),
616612
)
617613

618614
@endpoint
619-
async def _test_validate_model_params(self, validate_fn):
620-
"""Validate updated model params using validate_fn."""
621-
logger.info("[GeneratorWorker] start validating model parameters.")
615+
async def validate_model_params(self, validate_fn):
616+
"""Used for debugging purposes. Validate saved params using validate_fn."""
622617
return validate_fn(
623-
self._test_prev_params, self.worker.model_runner.model, logger
618+
self._debug_saved_params, self.worker.model_runner.model, logger
624619
)

tests/integration_tests/test_policy_update.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,18 +254,18 @@ async def test_sanity_check(self, _setup_and_teardown):
254254
# Setting everything to zero
255255
await rl_trainer.zero_out_model_states.call()
256256
await rl_trainer.push_weights.call(policy_version=v1)
257-
await policy._test_save_model_params.fanout()
257+
await policy.save_model_params.fanout()
258258

259259
# Sanity check that before update all the tests pass
260-
all_errs = await policy._test_validate_model_params.fanout(
260+
all_errs = await policy.validate_model_params.fanout(
261261
_test_validate_params_unchanged
262262
)
263263
for errs in all_errs:
264264
for _, e in errs.items():
265265
assert not e, f"Validation failed with exception: {e}"
266266

267267
await policy.update_weights.fanout(version=v1)
268-
all_errs = await policy._test_validate_model_params.fanout(
268+
all_errs = await policy.validate_model_params.fanout(
269269
_test_validate_params_all_zeros
270270
)
271271
for errs in all_errs:
@@ -274,7 +274,7 @@ async def test_sanity_check(self, _setup_and_teardown):
274274

275275
# Reloading v0, getting back original weights
276276
await policy.update_weights.fanout(version=v0)
277-
all_errs = await policy._test_validate_model_params.fanout(
277+
all_errs = await policy.validate_model_params.fanout(
278278
_test_validate_params_unchanged
279279
)
280280
for errs in all_errs:

0 commit comments

Comments
 (0)