Hi, here is a quick solution I whipped up, which I think works. To be honest, I'm not quite sure what your operation is doing, and I suspect there might be a way to rewrite it with masks so that it doesn't take so long to compile. As a general tip for rewriting, I find it helps to have a clear idea of what the JIT compiler has access to and what it doesn't. In this case, we are allowed to use I and J inside a jitted function, because they are fixed given the shape of the input (and the compiler knows the input shape when compiling). So we can use quantities like np.shape(X), write loops whose number of iterations depends on the shape or on quantities derived from it, and so on. An annoying detail is that compile time typically grows super-linearly with the length of Python loops, because they are statically unrolled during compilation. So the general advice is to avoid long Python loops where possible. Here there is likely some way to avoid the loops via a
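To illustrate the point about static shapes, here is a minimal sketch (not the solution referred to above, whose code isn't reproduced here). It assumes the sweep runs over anti-diagonals i + j = d, from the top-left corner to the bottom-right, and uses a made-up update rule: each cell becomes its own value plus the larger of its already-updated upper and left neighbours. The function name sweep_unrolled is hypothetical. Because n and m are plain Python ints at trace time, the per-diagonal index arrays can be built with NumPy and Python's min/max inside the jitted function:

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def sweep_unrolled(x):
    n, m = x.shape  # Python ints: shapes are static when the function is traced
    # Zero-padded work array so the first row/column have "neighbours".
    p = jnp.zeros((n + 1, m + 1), x.dtype)
    for d in range(n + m - 1):  # one anti-diagonal (i + j == d) per iteration
        # Python min/max on static ints, so i and j are concrete NumPy arrays
        # whose length is known at trace time.
        i = np.arange(max(0, d - m + 1), min(d, n - 1) + 1)
        j = d - i
        # Hypothetical update rule: cell + max(upper, left). In padded
        # coordinates, (i-1, j) is p[i, j+1] and (i, j-1) is p[i+1, j];
        # both lie on diagonal d - 1, which the previous iteration finished.
        new = x[i, j] + jnp.maximum(p[i, j + 1], p[i + 1, j])
        p = p.at[i + 1, j + 1].set(new)
    return p[1:, 1:]
```

The Python loop runs n + m - 1 times and is unrolled during tracing, so compile time grows with the matrix size, which is exactly the caveat about long Python loops.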
I have a function that constructs a matrix and then loops over its diagonals to update the entries (e.g. sweeping from the top-left diagonal to the bottom-right diagonal).
In NumPy, it looks something like this:
However, I'm having a hard time converting this logic to jit-able JAX code. The main bottleneck I found is that the length of the indices I, J that I use to index each diagonal depends on JAX operations (i.e. min, max). So it seems the logic above can't be translated directly into jit-able code. Do you have any advice on how to rephrase the logic so that it is jit-able? Thanks!
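When the diagonal index itself is a traced value, for example inside jax.lax.fori_loop, a variable-length I, J cannot be built, but a fixed-shape boolean mask can select the same cells. Below is a minimal sketch of that mask-based rewrite. It assumes an anti-diagonal sweep (i + j = d) and a hypothetical update rule where each cell becomes its own value plus the larger of its upper and left neighbours; sweep_masked is an illustrative name, not an existing API:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sweep_masked(x):
    n, m = x.shape
    p = jnp.zeros((n + 1, m + 1), x.dtype)  # zero-padded work array
    # Row/column index grids with a fixed shape (n, m), independent of d.
    ii, jj = jnp.meshgrid(jnp.arange(n), jnp.arange(m), indexing="ij")

    def body(d, p):
        # d is traced here, so instead of variable-length I, J we select the
        # cells of anti-diagonal d with a fixed-shape boolean mask.
        on_diag = (ii + jj) == d
        upper = p[:-1, 1:]  # element [i, j] is p[i, j+1], i.e. cell (i-1, j)
        left = p[1:, :-1]   # element [i, j] is p[i+1, j], i.e. cell (i, j-1)
        cand = x + jnp.maximum(upper, left)  # hypothetical rule, on every cell
        inner = jnp.where(on_diag, cand, p[1:, 1:])  # keep off-diagonal cells
        return p.at[1:, 1:].set(inner)

    p = jax.lax.fori_loop(0, n + m - 1, body, p)
    return p[1:, 1:]
```

The trade-off is extra work: every iteration touches the whole n-by-m array, so the total cost is roughly (n + m) * n * m instead of n * m. In exchange, the loop body is traced only once, so compile time does not grow with the number of diagonals.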