-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Added support for pytree types that inherit from tuple and typing.namedtuple #2845
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
python/mlx/utils.py
Outdated
| subtrees = [ | ||
| tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf) | ||
| for i, child in enumerate(tree) | ||
| ) | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be a generator, no need to actually expand a new list here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, sent the changes.
awni
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great. I think this is basically good to go. Can you address the comment then I can run the CI?
|
Thanks! It also needs to be rebased to resolve the conflict. If you can do it that'd be great, o/w I'll get to it later. |
|
I had some issues moving between different setups with the pre-commit hooks but its all finalized now, however this made me think: |
|
Thanks for the contribution!
I haven't seen any demand for this.. but if it becomes a common request we can consider it. |
Proposed changes
I've expanded support for tuple-related datatypes in the transforms and python tree utils.
The reason for this is that there seems to be a hole in the types available that work by default in mlx, at least in how I interact with these types of libraries. While having access to dicts, lists, and tuples is great, if I want to define a custom structure for my data and for it to work natively with the transforms, the closest thing available that has clearly labelled parts is dicts but dicts use more memory overhead, errors with keys can be quieter, and don't have the nice immutability property that namedtuples do.
The changes I've made add this functionality in a very lightweight way, and it doesn't change any previous workflow at all. Additionally, with the changes I also added the ability for types that inherit from tuples to be returned through transformations as it fit perfectly within this new system. Tests for everything were added to demonstrate all these new functionalities.
This helps match the documentation closely as the changes which are in the
tree_mapandtree_visit_updatefunctions still only focus on default python dict, list and tuple while adding greater support for objects that inherit from tuple now. The docs state:Functions in this module that return python trees will be using the default python dict, list and tuple but they can usually process objects that inherit from any of these.which I believe fits with the changes made.If needed,
typing.namedtuplecan be added to the list above.As a final benefit, this also helps people migrating from jax as namedtuples are supported by default there.
However, this is my first PR to this repo, if there any glaring mistakes or parts I've overlooked, or if this does not fit within what mlx wants to do, let me know.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes