Is this pattern already implemented somewhere? #16259
-
While writing both a 4-bit matrix implementation and a LoRA implementation, I noticed a common pattern: a "thing which represents a matrix but avoids instantiating a dense matrix if possible." For example, with 4-bit quantization you prefer to use a specialized matmul kernel, but if you need to do something like calculate the norm of the matrix, you can always fully materialize it. It's pretty easy to implement this pattern with tracers, but handling details like the pjit and remat primitives every time is annoying. Is this special case already implemented somewhere? For example, I know there's a …
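A minimal sketch of the pattern for the LoRA case, assuming JAX and a hypothetical `LoraMatrix` class (not an existing JAX API). Registering the class as a pytree lets it pass through `jax.jit` as an ordinary argument. `__rmatmul__` supplies the fast path, and `materialize()` is the dense fallback. Unlike the tracer-based approach described above, this only works where you control the call sites, since `jnp` functions will not accept the object directly.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class LoraMatrix:
    """Hypothetical lazy matrix representing W + A @ B without forming A @ B.

    W: (d_in, d_out) base weight; A: (d_in, r) and B: (r, d_out) low-rank factors.
    """

    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b

    @property
    def shape(self):
        return self.w.shape

    # Pytree registration lets instances flow through jit, grad, vmap, etc.
    def tree_flatten(self):
        return (self.w, self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def materialize(self):
        # Dense fallback for ops with no specialized path (e.g. norms).
        return self.w + self.a @ self.b

    def __rmatmul__(self, x):
        # Fast path: x @ (W + A B) == x @ W + (x @ A) @ B, so the
        # (d_in, d_out) product A @ B is never formed. JAX arrays return
        # NotImplemented for unrecognized operands, so `x @ m` lands here.
        return x @ self.w + (x @ self.a) @ self.b


@jax.jit
def forward(x, m):
    return x @ m


@jax.jit
def frobenius_norm(m):
    # No specialized kernel for norms, so materialize the dense matrix.
    return jnp.linalg.norm(m.materialize())


# Usage: a rank-8 adapter on a 512 x 256 weight.
k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
m = LoraMatrix(jax.random.normal(k_w, (512, 256)),
               jax.random.normal(k_a, (512, 8)),
               jnp.zeros((8, 256)))
y = forward(jax.random.normal(k_x, (4, 512)), m)  # shape (4, 256)
```

The same skeleton should carry over to the 4-bit case by storing packed weights and scales as leaves and swapping `__rmatmul__` for a call to the specialized kernel.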
Replies: 2 comments 20 replies
-
-
I have two other cases where this would be a useful feature: sparse arrays and symbolic zeros (the latter outside of AD, as a general jaxtype). I note that sparse arrays already exist and have some limited support. Would it be possible to build a single unified approach into JAX?
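The symbolic-zero case fits the same mold. As a sketch only, here is a hypothetical `ZeroMatrix` (not JAX's AD-internal symbolic zeros, and not an existing API) that carries its shape and dtype as static pytree metadata, so arithmetic involving it is resolved at trace time and never allocates zeros. It handles 2-D operands only and ignores dtype promotion and broadcasting.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ZeroMatrix:
    """Hypothetical all-zeros 2-D matrix that stores only its shape and dtype."""

    def __init__(self, shape, dtype=jnp.float32):
        self.shape = tuple(shape)
        self.dtype = dtype

    # No array leaves: shape and dtype travel as static aux data, so a
    # ZeroMatrix crossing a jit boundary costs no device memory.
    def tree_flatten(self):
        return (), (self.shape, self.dtype)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)

    def materialize(self):
        return jnp.zeros(self.shape, self.dtype)

    def __add__(self, other):
        # 0 + x == x (assumes x already has the result shape and dtype).
        return other

    __radd__ = __add__

    def __matmul__(self, other):
        # (m, k) zeros @ (k, n) -> (m, n) zeros, still symbolic.
        return ZeroMatrix((self.shape[0], other.shape[1]), self.dtype)

    def __rmatmul__(self, other):
        # (m, k) @ (k, n) zeros -> (m, n) zeros, still symbolic.
        return ZeroMatrix((other.shape[0], self.shape[1]), self.dtype)


@jax.jit
def residual_update(x, w, delta):
    # When delta is a ZeroMatrix, the `x @ delta` term is traced away entirely.
    return x @ w + x @ delta
```

Because `ZeroMatrix` has no array leaves, `jax.jit` treats its shape and dtype as part of the cache key, so passing a different shape triggers a retrace.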
No, I don't believe there's any stock pattern for this kind of thing. One note on `__jax_array__` is that it's not a fully implemented API (and probably never will be; see #4725), so I would not recommend relying on it (cc @mattjj, who knows more about the plans for this).