Skip to content

Add support for MLX arrays#301

Closed
gabrieldemarmiesse wants to merge 6 commits intopatrick-kidger:mainfrom
gabrieldemarmiesse:add_mlx_support
Closed

Add support for MLX arrays#301
gabrieldemarmiesse wants to merge 6 commits intopatrick-kidger:mainfrom
gabrieldemarmiesse:add_mlx_support

Conversation

@gabrieldemarmiesse
Copy link

@gabrieldemarmiesse gabrieldemarmiesse commented Feb 21, 2025

Fixes #299

@gabrieldemarmiesse gabrieldemarmiesse changed the title Allow MLX types to be recognized Add support for MLX arrays Feb 21, 2025
@gabrieldemarmiesse gabrieldemarmiesse marked this pull request as ready for review February 21, 2025 10:50
if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
# Nanobind classes can't be a base type
can_subclass = array_type is Any or "nanobind." in str(type(array_type))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was afraid things might look something like this!

FWIW we could maybe just remove subclassing altogether -- I'm not sure how much this buys us, and we're really threading the needle here to do so -- WDYT?

(I think I might have introduced this subclassing to make things work better when using type annotations with plum, although I might have that wrong. I think that's a use-case that can be supported in other ways, though.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no opinion on the matter. I'll remove the subclassing entirely and see where it goes

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

______________________________________________________ test_subclass _______________________________________________________

    def test_subclass():
>       assert issubclass(Float[Array, ""], Array)
E       AssertionError: assert False
E        +  where False = issubclass(<class 'jaxtyping.Float[Array, '']'>, Array)

test/test_array.py:607: AssertionError

The only test failing is this one: test/test_array.py::test_subclass - AssertionError: assert False. Every assertion on this test is failing.
I find it strange that a type hint need to be a subclass of a type. If that's not needed, we can remove this completely. How should I proceed?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. I think let's go ahead and remove it then!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, I'll let you review one more time!

)

metaclass = _make_metaclass(type)
out = metaclass(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we skip creating the metaclass too? And have this line just be type(...

@gabrieldemarmiesse
Copy link
Author

Closing this pull request as #309 already add support for mlx arrays. Thanks @patrick-kidger !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support for MLX

2 participants