Error related to static arguments when using classes in JAX. #8031
-
Dear JAX experts, I am a fairly new user of JAX but have been pleasantly surprised by the rewards of using it. I am trying to implement a class structure (yes, I have read the discussion in 1567 and the threads leading up to it, and I am trying out classes and namedtuples simultaneously). I run the following code
and come across the error
I am not sure why this is happening. The error is on line 82, "supmat = jnp.zeros((dim_super, dim_super), dtype='float32')". I declare CENMULT as a static argument because the dimension of "supmat" depends on CENMULT, and NEIGHBOUR_DICT changes only when CENMULT changes. So I was under the impression that this should work. Any suggestions?
-
Thanks for the question! When you run `dim_super = jnp.sum(dimX_submat[0, :])` within a JIT-compiled context, the result is a traced value regardless of whether the input is static. The sizes of matrices cannot depend on traced values, so when you run `jnp.zeros((dim_super, dim_super))` you get the tracer conversion error. I haven't looked closely at your full algorithm, but in order to use it within JAX transforms like JIT, you will have to write it in such a way that array sizes are static: one useful pattern is to use standard NumPy ops when dealing with static values. For more information on this, the How to think in JAX document is a useful read.
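A minimal sketch of the failure and of the NumPy pattern described above, assuming the sizes are passed as a hashable tuple (static arguments to `jax.jit` must be hashable, so a NumPy array such as `dimX_submat` could not be passed statically as-is). The names `build_supmat_traced`, `build_supmat_static`, and `block_sizes` are hypothetical stand-ins, since the original code is not shown.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


# block_sizes is a hypothetical stand-in for the static sizes (e.g. a row of
# dimX_submat), passed as a tuple because static arguments must be hashable.
@partial(jax.jit, static_argnums=0)
def build_supmat_traced(block_sizes):
    # jnp.sum is staged out by jit, so dim_super is a tracer even though
    # its input is static.
    dim_super = jnp.sum(jnp.array(block_sizes))
    # Fails: array shapes must be concrete integers, not tracers.
    return jnp.zeros((dim_super, dim_super), dtype='float32')


@partial(jax.jit, static_argnums=0)
def build_supmat_static(block_sizes):
    # np.sum runs as ordinary Python at trace time, so dim_super is a plain int.
    dim_super = int(np.sum(block_sizes))
    return jnp.zeros((dim_super, dim_super), dtype='float32')


supmat = build_supmat_static((2, 3))
print(supmat.shape)  # (5, 5)

try:
    build_supmat_traced((2, 3))
except TypeError as err:  # JAX raises a TypeError for the non-concrete shape
    print("traced shape rejected:", type(err).__name__)
```

Calling `build_supmat_static` with a different tuple triggers a fresh trace and compilation, which is expected because the output shape changes.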
-
AH! I see. I was confused for the longest time as to why the above example fails while the following simple code works
I tried out a bunch of things and realized that it works as long as the scalars defining the dimensions of arrays created inside jitted functions are not traced values. In the simple example, the size is not a traced value, but in my original post it became one after being acted on by jax.numpy operations. Now your suggestions make a lot of sense.

Just to reiterate, so that I know I have understood this correctly: any scalar that determines the dimensions of an array must not be a traced value. So, throughout the code, to keep such scalars from being converted to traced values, we need to apply numpy operations to them instead of jax.numpy operations. Is this correct?

My final question is then: in such cases, could these numpy operations (which in my mind are speed breakers) compromise the speedup that JAX gives us? Would the code be slowed down by such operations even when there is no recompilation, i.e., when the static arguments are not changing?
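One way to probe the performance question is a small sketch that counts how often the Python body of a jitted function actually runs. Python code, including NumPy operations on static values, executes only while JAX traces the function; later calls with the same static arguments (and the same shapes and dtypes for the other inputs) reuse the cached compiled version without re-running that Python code. The function `scaled_identity` and the `trace_count` counter are hypothetical, for illustration only.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: how many times the Python body runs


@partial(jax.jit, static_argnums=0)
def scaled_identity(block_sizes, x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    n = int(np.sum(block_sizes))  # NumPy work on the static value, also trace-time only
    return x * jnp.eye(n)


a = scaled_identity((2, 3), 1.0)  # first call: trace and compile
b = scaled_identity((2, 3), 2.0)  # same static arg and input type: cached, no retrace
print(trace_count)  # 1
c = scaled_identity((4,), 1.0)    # new static arg: traced and compiled again
print(trace_count)  # 2
```

Under these assumptions, the NumPy work on static values is paid once per distinct set of static arguments, at trace time, rather than on every call.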