Add support for MLX arrays#301
Add support for MLX arrays#301gabrieldemarmiesse wants to merge 6 commits intopatrick-kidger:mainfrom
Conversation
jaxtyping/_array_types.py
Outdated
| if type(out) is tuple: | ||
| array_type, name, dtypes, dims, index_variadic, dim_str = out | ||
| # Nanobind classes can't be a base type | ||
| can_subclass = array_type is Any or "nanobind." in str(type(array_type)) |
There was a problem hiding this comment.
I was afraid things might look something like this!
FWIW we could maybe just remove subclassing altogether -- I'm not sure how much this buys us, and we're really threading the needle here to do so -- WDYT?
(I think I might have introduced this subclassing to make things work better when using type annotations with plum, although I might have that wrong. I think that's a use-case that can be supported in other ways, though.)
There was a problem hiding this comment.
I have no opinion on the matter. I'll remove the subclassing entirely and see where it goes
There was a problem hiding this comment.
______________________________________________________ test_subclass _______________________________________________________
def test_subclass():
> assert issubclass(Float[Array, ""], Array)
E AssertionError: assert False
E + where False = issubclass(<class 'jaxtyping.Float[Array, '']'>, Array)
test/test_array.py:607: AssertionError
The only test failing is this one: test/test_array.py::test_subclass - AssertionError: assert False. Every assertion on this test is failing.
I find it strange that a type hint need to be a subclass of a type. If that's not needed, we can remove this completely. How should I proceed?
There was a problem hiding this comment.
Okay. I think let's go ahead and remove it then!
There was a problem hiding this comment.
Done, I'll let you review one more time!
| ) | ||
|
|
||
| metaclass = _make_metaclass(type) | ||
| out = metaclass( |
There was a problem hiding this comment.
Can we skip creating the metaclass too? And have this line just be type(...
|
Closing this pull request as #309 already add support for mlx arrays. Thanks @patrick-kidger ! |
Fixes #299