Skip to content

Commit 39695a4

Browse files
committed
address comments
1 parent 47695ef commit 39695a4

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

src/forge/actors/generator.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -494,13 +494,11 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
494494
@endpoint
495495
async def save_model_params(self):
496496
"""Used for debugging purpose. Save model parameters before weight update."""
497-
logger.info("[Generator] save model parameters for debugging.")
498497
await self.worker.save_model_params.call()
499498

500499
@endpoint
501500
async def validate_model_params(self, validate_fn):
502501
"""Used for debugging purpose. Validate saved params using validate_fn."""
503-
logger.info("[Generator] start validating saved model parameters.")
504502
return await self.worker.validate_model_params.call(validate_fn)
505503

506504

@@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor):
514512
"""
515513

516514
vllm_config: VllmConfig
517-
# Used for debugging purposes only
518-
debug_saved_params = {}
519515

520516
@endpoint
521517
async def setup(self):
@@ -607,18 +603,17 @@ async def update_weights(self, version: int) -> None:
607603
@endpoint
608604
async def save_model_params(self):
609605
"""Used for debugging purposes. Save model parameters before weight update."""
610-
logger.info("[GeneratorWorker] save model parameters for debugging.")
606+
self._debug_saved_params = {}
611607
for name, param in self.worker.model_runner.model.named_parameters():
612-
self.debug_saved_params[name] = param.detach().cpu()
608+
self._debug_saved_params[name] = param.detach().cpu()
613609
logger.info(
614610
"[GeneratorWorker] finished saving model parameters, len = %d",
615-
len(self.debug_saved_params),
611+
len(self._debug_saved_params),
616612
)
617613

618614
@endpoint
619615
async def validate_model_params(self, validate_fn):
620616
"""Used for debugging purposes. Validate saved params using validate_fn."""
621-
logger.info("[GeneratorWorker] start validating saved model parameters.")
622617
return validate_fn(
623-
self.debug_saved_params, self.worker.model_runner.model, logger
618+
self._debug_saved_params, self.worker.model_runner.model, logger
624619
)

0 commit comments

Comments
 (0)