-
Notifications
You must be signed in to change notification settings - Fork 16
Rename _test
method in generator.py
#421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
src/forge/actors/generator.py
Outdated
logger.info("[Generator] save model parameters for testing.") | ||
await self.generator_worker._test_save_model_params.call() | ||
async def save_model_params(self): | ||
"""Save model parameters before weight update, used for debugging purposes.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Can we bump the for debugging purpose to the front? It's easy to miss at the end of a line
"""Save model parameters before weight update, used for debugging purposes.""" | |
"""(For Debugging/Testing): Save model parameters before weight update.""" |
2cbf7b1
to
37e08bf
Compare
src/forge/actors/generator.py
Outdated
await self.worker._test_save_model_params.call() | ||
async def save_model_params(self): | ||
"""Used for debugging purpose. Save model parameters before weight update.""" | ||
logger.info("[Generator] save model parameters for debugging.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove the logging?
src/forge/actors/generator.py
Outdated
return await self.worker._test_validate_model_params.call(validate_fn) | ||
async def validate_model_params(self, validate_fn): | ||
"""Used for debugging purpose. Validate saved params using validate_fn.""" | ||
logger.info("[Generator] start validating saved model parameters.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove logging
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove logging? These methods are only used in tests and won't spam the logs otherwise.
async def _test_validate_model_params(self, validate_fn): | ||
"""Validate updated model params using validate_fn.""" | ||
logger.info("[GeneratorWorker] start validating model parameters.") | ||
async def validate_model_params(self, validate_fn): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually - could we expose an endpoint to get the underlying model? Then these methods could be entirely deleted, no?
Sorry if I'm misunderstanding. But I don't like having a debug param saved on this model :/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's a fair tradeoff to expose certain method for debugging purpose.
The beauty of exposing a validator function here is it's a very flexible and reliable way to poke at the property on each worker shard and saves the cost of message SerDer.
Another restriction here is Policy is behind Service interface....
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
saves the cost of message SerDer.
Do we imagine testing w/ these methods at a substantial scale?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
saves the cost of message SerDer.
Do we imagine testing w/ these methods at a substantial scale?
My hunch is so long as we're still iterating on weight sync strategy, having some capability to poke into the states across multiple workers under multi-node, multi-replica run might be super valuable.
async def _test_save_model_params(self): | ||
"""Save model parameters before weight update, used for tesing purposes only.""" | ||
logger.info("[GeneratorWorker] save model parameters for testing.") | ||
async def save_model_params(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this method return all the model params? Instead of saving them to the model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I am late to the party. The short answer is no: the workers are remote and it's not feasible to send all the parameters through RPC.
37e08bf
to
39695a4
Compare
Trying to make the
Generator
API a bit cleaner.