Are custom datatypes supported for JAX? #19512
-
I found this: #5356. I would like to create a wrapper around NumPy or JAX arrays that supports some custom operations and falls back to `jax.numpy` operations wherever it can. Is something like this possible?
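For concreteness, here is a minimal sketch of the kind of wrapper described, assuming a hypothetical `MyArray` class (not a JAX API). Registering it as a pytree lets it pass through `jax.jit`, `normalize` stands in for a custom operation, and `__getattr__` forwards unknown public names to the matching `jax.numpy` function. This is a user-level pattern rather than a JAX extension mechanism: `jax.numpy` functions still return plain arrays, not `MyArray`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MyArray:
    """Hypothetical wrapper around a JAX array (illustrative, not a JAX API)."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    # Pytree protocol: lets MyArray cross jit/grad/vmap boundaries.
    def tree_flatten(self):
        return (self.value,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so JAX can rebuild the object from tracers or
        # placeholder leaves without calling jnp.asarray on them.
        obj = object.__new__(cls)
        obj.value = children[0]
        return obj

    # A custom operation that only the wrapper provides.
    def normalize(self):
        return MyArray(self.value / jnp.linalg.norm(self.value))

    # Fallback: unknown public names are looked up in jax.numpy and applied
    # to the wrapped array. Results are plain jax.Arrays, not MyArray.
    def __getattr__(self, name):
        if name.startswith("_"):
            raise AttributeError(name)
        fn = getattr(jnp, name)  # AttributeError if jax.numpy has no such name
        return lambda *args, **kwargs: fn(self.value, *args, **kwargs)


@jax.jit
def f(x):
    # Inside jit, x is a MyArray whose leaf is a tracer.
    return x.normalize()


x = MyArray([3.0, 4.0])
y = f(x)          # a MyArray: the pytree structure survives jit
total = y.sum()   # falls back to jnp.sum(y.value)
```

With this naive fallback every `jax.numpy` name becomes a method, so e.g. `y.shape()` is a call that returns a tuple rather than a property.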
Replies: 3 comments 10 replies
-
-
Could you say a bit more about your exact use case? My Qax transform can get you some custom behavior, although part of its point is to make your object look like an ordinary array. It's basically a helper that makes a fairly common type of JAX transform easier to write.
-
+1 for Davis' recommendation of Qax as the appropriate solution to this problem. And of course, I'll also advertise my own take on this, Quax, which as the name suggests is inspired by Davis' work. We've just done a 0.0.3 release and now have some shiny new documentation. (@Jacob-Stevens-Haas: you mention wanting named axes. As a toy implementation of exactly that, see …
No, unfortunately there's not currently much support for that kind of extensibility, but it's something we're thinking about. Can you say more about what you have in mind? For example, would a mechanism like `__jax_array__` be sufficient, or would you want JAX APIs to preserve the type of the objects on output?
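To illustrate the `__jax_array__` option mentioned above, here is a minimal sketch assuming a hypothetical `Tagged` wrapper class (not part of JAX). `__jax_array__` is an experimental, lightly documented hook, so its behavior may vary between JAX versions; this sketch only relies on `jnp.asarray` honoring it. The result of the `jax.numpy` call is a plain `jax.Array`, which is exactly the type-preservation limitation the question raises.

```python
import jax
import jax.numpy as jnp


class Tagged:
    """Hypothetical wrapper: holds an array plus a string tag."""

    def __init__(self, value, tag):
        self.value = jnp.asarray(value)
        self.tag = tag

    def __jax_array__(self):
        # Hook JAX calls when it needs to convert this object to an array.
        return self.value

    def describe(self):
        # A "custom operation" that only the wrapper knows about.
        return f"{self.tag}: shape={self.value.shape}"


t = Tagged([0.0, 1.0, 2.0], tag="radians")
arr = jnp.asarray(t)  # converted via __jax_array__
out = jnp.sin(arr)    # ordinary jax.numpy call; the tag is lost
print(type(out).__name__, t.describe())
```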