jax.numpy.count_nonzero output #10329
-
I just noticed a difference in output types. This:

    indices = jnp.arange(196) % 4
    a = jnp.count_nonzero(indices == 0, axis=0, keepdims=False)

returns a DeviceArray, whereas this:

    indices = jnp.arange(196) % 4
    a = np.count_nonzero(indices == 0, axis=0, keepdims=False)

returns a scalar. I don't know if this is intended behavior, but it is causing me some headache when trying to jit a function, because I'm trying to set the size of an array using the number of zeros in some other array. Normally, I would turn the output of the first bit of code into a scalar. Any suggestions for getting the number of zeros in an array in a jittable form? Apologies in advance if I'm being dense and missing something obvious.
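For reference, the difference can be reproduced with a short sketch like the one below. The variable names are illustrative only, and `int(...)` is just one way of getting a Python scalar from a concrete value; it is an assumption that this is acceptable here, and it only works outside of jit.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(196) % 4
x_jax = jnp.arange(196) % 4

n_np = np.count_nonzero(x_np == 0)     # NumPy: an integer scalar
n_jax = jnp.count_nonzero(x_jax == 0)  # JAX: a zero-dimensional array

# Show the two result types and the (empty) shape of the JAX result.
print(type(n_np), type(n_jax), n_jax.shape)

# Outside of jit the value is concrete, so it converts to a Python int.
n_py = int(n_jax)
```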
-
Thanks for the question! There are a couple of things going on here.

First, while NumPy has special scalar types that are distinct from (and behave differently than) zero-dimensional arrays, JAX made the design choice early on not to have separate scalar types. So any function that returns a scalar value will represent it as a zero-dimensional DeviceArray.

Second, this design choice aside, it sounds like you are trying to create an array whose shape depends on the content of another array. This is fine if you're doing it outside of JIT or other transformations, but inside transformations it is not possible. Why? JAX's compilation and transformation model depends on array shapes and data types being static, that is, known at trace time. In particular, this means that you cannot JIT-compile any program where the shape of an array depends on the values within another array. This is a fundamental limitation of JAX's tracing model as it is currently implemented (though there is ongoing experimental work toward relaxing this limitation). For some more background on JIT and tracing, you might check out https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables

Depending on what you're hoping to do with the resulting value, there are a variety of workarounds you might use, but none of them involve creating an array of dynamic shape. If you can share a more complete example of your program logic, we may be able to suggest how to re-express it in a JIT-compatible way.
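As a minimal sketch of the static-shape rule (the function names `count_zeros` and `zero_positions_padded` and the `-1` fill value are illustrative assumptions, not part of the original question): counting under jit is fine because the result always has shape `()`, while anything whose length would depend on that count has to be given a fixed size instead. `jnp.nonzero` accepts `size` and `fill_value` arguments for this purpose.

```python
import jax
import jax.numpy as jnp


@jax.jit
def count_zeros(x):
    # Fine under jit: the result always has the static shape ().
    return jnp.count_nonzero(x == 0)


@jax.jit
def zero_positions_padded(x):
    # Under jit, jnp.nonzero needs a static `size`. Here the input length
    # serves as an upper bound, and unused slots are filled with -1.
    (idx,) = jnp.nonzero(x == 0, size=x.shape[0], fill_value=-1)
    return idx


x = jnp.arange(8) % 4            # [0, 1, 2, 3, 0, 1, 2, 3]
n = count_zeros(x)               # zero-dimensional array
idx = zero_positions_padded(x)   # fixed shape (8,), padded with -1
```

By contrast, using the traced count as a shape inside a jitted function (for example `jnp.zeros(n)`) fails at trace time, because the value of `n` is not known when the function is traced.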
-
Hi Jake,

Thanks for the quick reply. Let me first post the function.

    import numpy as np

    def _reduce_indices(indices, p):
        """
        Minimally reduce indices to ensure that it can be safely pooled from nside -> nside/(2**p)
        -------------
        :param indices: array of pixel indices in NEST ordering
        :param p: reduction factor, nside -> nside/(2**p)
        -------------
        returns: array of pixel indices in NEST ordering
        """
        # Start index of the block of 4**p pixels that each index belongs to.
        indices_2 = indices - indices % (4**p)
        # Zero where the entry 4**p - 1 positions later lies in the same block.
        indices_2 = indices_2 - np.roll(indices_2, -(4**p)+1)
        new_indices = np.repeat(indices[indices_2==0], 4**p)  # <--- dynamic shape.
        new_indices = np.reshape(new_indices, (-1, 4**p))
        new_indices = new_indices + np.arange(4**p)
        new_indices = new_indices.flatten()
        return new_indices

This is a simple utility function to calculate the mask to apply on images on the sphere, and it is going to be used by pooling/convolutional layers in a graph convolutional neural network model I'm trying to implement. Here is my naive implementation:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def _reduce_indices_jax(indices, p):
        """
        Minimally reduce indices to ensure that it can be safely pooled from nside -> nside/(2**p)
        -------------
        :param indices: array of pixel indices in NEST ordering
        :param p: reduction factor, nside -> nside/(2**p)
        -------------
        returns: array of pixel indices in NEST ordering
        """
        indices_2 = indices - indices % (4**p)
        indices_2 = indices_2 - jnp.roll(indices_2, -(4**p)+1)
        count_zero = jnp.count_nonzero(indices_2==0)  # <--- DeviceArray instead of scalar.
        new_indices = jnp.repeat(indices[indices_2==0], 4**p, total_repeat_length=count_zero*(4**p))
        new_indices = jnp.reshape(new_indices, (-1, 4**p))
        new_indices = new_indices + jnp.arange(4**p)
        new_indices = new_indices.flatten()
        return new_indices

I can imagine padding and then dealing with the padded values later on, so it's not a huge issue, but I was just wondering about the reasoning behind the choice of return type.

Thanks again for the quick reply!
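The padding idea mentioned above could look roughly like the sketch below. It is only one possible approach, under several assumptions that are not in the original post: `p` is treated as a static argument, `indices` is assumed to be sorted and free of duplicates, the output is padded to the worst case of one block per input index, and `-1` is used as the padding marker. The name `reduce_indices_padded` is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="p")
def reduce_indices_padded(indices, p):
    block = 4 ** p  # a Python int, because p is static
    parents = indices - indices % block
    keep = (parents - jnp.roll(parents, -block + 1)) == 0

    n = indices.shape[0]  # static worst case: every index starts a block
    (pos,) = jnp.nonzero(keep, size=n, fill_value=0)
    n_blocks = jnp.count_nonzero(keep)  # zero-dimensional array, not a shape
    valid = jnp.arange(n) < n_blocks    # True for real blocks, False for padding

    children = indices[pos][:, None] + jnp.arange(block)
    children = jnp.where(valid[:, None], children, -1)  # -1 marks padding
    return children.reshape(-1), n_blocks * block


indices = jnp.array([0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 13])
padded, n_valid = reduce_indices_padded(indices, p=1)
```

The function returns a fixed-length array together with the number of valid leading entries, so downstream code can mask out the padding rather than relying on a data-dependent shape.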