-
Notifications
You must be signed in to change notification settings - Fork 18
Rename `_test` method in generator.py
#421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -492,16 +492,14 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] | |
await stop_proc_mesh(actor._generator_proc) | ||
|
||
@endpoint | ||
async def _test_save_model_params(self): | ||
"""Save model parameters before weight update, used for tesing purposes only.""" | ||
logger.info("[Generator] save model parameters for testing.") | ||
await self.worker._test_save_model_params.call() | ||
async def save_model_params(self): | ||
"""Used for debugging purpose. Save model parameters before weight update.""" | ||
await self.worker.save_model_params.call() | ||
|
||
@endpoint | ||
async def _test_validate_model_params(self, validate_fn): | ||
"""Validate updated model params using validate_fn.""" | ||
logger.info("[Generator] start validating model parameters.") | ||
return await self.worker._test_validate_model_params.call(validate_fn) | ||
async def validate_model_params(self, validate_fn): | ||
"""Used for debugging purpose. Validate saved params using validate_fn.""" | ||
return await self.worker.validate_model_params.call(validate_fn) | ||
|
||
|
||
@dataclass | ||
|
@@ -514,8 +512,6 @@ class GeneratorWorker(ForgeActor): | |
""" | ||
|
||
vllm_config: VllmConfig | ||
# TODO: Remove below param | ||
_test_prev_params = {} | ||
|
||
@endpoint | ||
async def setup(self): | ||
|
@@ -605,20 +601,19 @@ async def update_weights(self, version: int) -> None: | |
t.stop() | ||
|
||
@endpoint | ||
async def _test_save_model_params(self): | ||
"""Save model parameters before weight update, used for tesing purposes only.""" | ||
logger.info("[GeneratorWorker] save model parameters for testing.") | ||
async def save_model_params(self): | ||
"""Used for debugging purposes. Save model parameters before weight update.""" | ||
self._debug_saved_params = {} | ||
for name, param in self.worker.model_runner.model.named_parameters(): | ||
self._test_prev_params[name] = param.detach().cpu() | ||
self._debug_saved_params[name] = param.detach().cpu() | ||
logger.info( | ||
"[GeneratorWorker] finished saving model parameters, len = %d", | ||
len(self._test_prev_params), | ||
len(self._debug_saved_params), | ||
) | ||
|
||
@endpoint | ||
async def _test_validate_model_params(self, validate_fn): | ||
"""Validate updated model params using validate_fn.""" | ||
logger.info("[GeneratorWorker] start validating model parameters.") | ||
async def validate_model_params(self, validate_fn): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually - could we expose an endpoint to get the underlying model? Then these methods could be entirely deleted, no? Sorry if I'm misunderstanding. But I don't like having a debug param saved on this model :/ There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's a fair tradeoff to expose certain methods for debugging purposes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Do we imagine testing w/ these methods at a substantial scale? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
My hunch is that, so long as we're still iterating on the weight sync strategy, having some capability to poke into the state across multiple workers in a multi-node, multi-replica run might be super valuable. |
||
"""Used for debugging purposes. Validate saved params using validate_fn.""" | ||
return validate_fn( | ||
self._test_prev_params, self.worker.model_runner.model, logger | ||
self._debug_saved_params, self.worker.model_runner.model, logger | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this method return all the model params instead of saving them to the model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I am late to the party. The short answer is no: the workers are remote and it's not feasible to send all the parameters through RPC.