[2/N] Core trainer abstraction #158
Conversation
Yes, I thought it would create a stacked diff, but it seems like this is cumulative. :( Let me know if there is an easy way to fix this.
I'm mostly on board with the trainer interface; I just have concerns about whether compile and pipeline work well with these.
    pass

    @abstractmethod
    def apply_gradients(self) -> None:
In general this looks fine; my main concern is whether this would be compatible with the Compile and Pipeline parallel APIs? @H-Huang
I don't have a concern if we want to expose another Step API for the trainer.
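For context, a minimal sketch of how such a Step API could compose with the abstract methods in this PR; the Trainer class name and the forward_backward method below are illustrative assumptions, not the PR's actual signatures.

from abc import ABC, abstractmethod


class Trainer(ABC):
    # Illustrative sketch only; every name here except apply_gradients is assumed.

    @abstractmethod
    def forward_backward(self, batch) -> None:
        # Run the forward and backward passes, accumulating gradients.
        ...

    @abstractmethod
    def apply_gradients(self) -> None:
        # Apply the accumulated gradients to the model parameters.
        ...

    def step(self, batch) -> None:
        # A possible "Step" API composing the two phases, as suggested above.
        self.forward_backward(batch)
        self.apply_gradients()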
    pass

    @abstractmethod
    def snapshot_weights(self) -> WeightsBuffer:
This would likely push weights to the store for checkpoint handling and weight sync to take over.
Similar to update_weights in the policy, this will be somewhat dependent on the internal state of apply_gradients: you want to call it right after apply_gradients is done (without awaiting it), and then not call apply_gradients again until the snapshot has completed. Not as complex as the policy side, but something to keep in mind.
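A rough sketch of that ordering constraint, assuming snapshot_weights is awaitable and that a forward_backward helper exists (both assumptions for illustration; the PR's actual signatures may be synchronous):

import asyncio


async def train_loop(trainer, batches):
    snapshot = None
    for batch in batches:
        if snapshot is not None:
            # Do not call apply_gradients again until the previous
            # snapshot has finished reading the weights it depends on.
            await snapshot
        trainer.forward_backward(batch)  # hypothetical helper
        trainer.apply_gradients()
        # Kick off the snapshot right after apply_gradients completes,
        # without awaiting it, so the weight push overlaps with the
        # next forward/backward.
        snapshot = asyncio.create_task(trainer.snapshot_weights())
    if snapshot is not None:
        await snapshot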
# TODO: This file should NOT be in the data_models folder/package


class Store(ABC):
I left my comment in 3/N, but to repeat it here: is it valuable to abstract the buffer too? It's as core to the library as Monarch.
It's as core to the library as Monarch.
The buffer is just a wrapper on top of the store, hence I did not do that. Can you elaborate on your reasoning for abstracting the buffer?
[EDIT]: I don't have an opinion, but it does not hurt to abstract the buffer too.
    pass


class WeightsBuffer:
I don't follow the reason to have this extra layer. Also, a buffer is what holds some individual data, whereas this would be the entire store?
At a high level:
Store is a generic key-value storage abstraction. It can store any kind of data (strings, tensors, configs, etc.), not just model weights.
WeightsBuffer is a specialized abstraction focused on the logic and conventions for storing and retrieving model weights. It may add domain-specific features, validation, serialization, or metadata handling that are unique to weights.
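A hedged sketch of that relationship; the method names (put/get, push/pull) and the versioned-key convention are assumptions for illustration, not the actual interface in this PR:

from abc import ABC, abstractmethod
from typing import Any, Mapping


class Store(ABC):
    # Generic key-value storage; values are not weight-specific.

    @abstractmethod
    def put(self, key: str, value: Any) -> None: ...

    @abstractmethod
    def get(self, key: str) -> Any: ...


class WeightsBuffer:
    # Thin weights-focused wrapper on top of a Store: key naming,
    # versioning, and any weight-specific validation would live here.

    def __init__(self, store: Store, prefix: str = "weights") -> None:
        self._store = store
        self._prefix = prefix

    def push(self, version: int, state_dict: Mapping[str, Any]) -> None:
        # Store the whole state dict under a versioned key.
        self._store.put(f"{self._prefix}/{version}", dict(state_dict))

    def pull(self, version: int) -> Mapping[str, Any]:
        return self._store.get(f"{self._prefix}/{version}")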
A split version of this RFC