-
So, I am trying to write a function that implements a Top-k pooling layer in Jraph, but I am running into issues with vmap. It definitely has to do with how masking + jit does not really work with dynamically-shaped arrays, but I was wondering if there is any workaround to do the same thing. Here's a minimal example of what I want to do.

Starting from this graph:

Here's the code to do it for a single graph:

This works fine, but when I try to vmap it over a batch of graphs, it doesn't work:

Any ideas on how to work around this?
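Here is a minimal sketch that reproduces this kind of failure. It assumes the pooling step selects nodes with `jnp.nonzero` on a boolean node mask. `select_nodes` is a hypothetical stand-in, not the code from the question. The call works eagerly on one graph. Under `jax.vmap`, however, the output length depends on the data, so JAX raises a `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def select_nodes(mask):
    # Hypothetical stand-in for the pooling step: indices of nodes to keep.
    # The output length equals the number of True entries, i.e. it depends on data.
    (idx,) = jnp.nonzero(mask)
    return idx


# One graph: eager execution is fine because the mask values are concrete.
single = jnp.array([True, False, True, True])
print(select_nodes(single).tolist())  # [0, 2, 3]

# A batch of two graphs: under vmap the mask is a tracer, so jnp.nonzero
# cannot determine its output length (3 for one graph, 1 for the other).
batch = jnp.array([[True, False, True, True],
                   [False, True, False, False]])
try:
    jax.vmap(select_nodes)(batch)
except jax.errors.ConcretizationTypeError as err:
    print("vmap failed:", type(err).__name__)
```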
-
The issue here is that you're attempting to create a dynamically-shaped array (see JAX sharp bits: dynamic shapes). The semantics of `jnp.nonzero` are that it returns an array whose length is the number of nonzero elements in its argument. That length cannot be known at compile time, so using it this way is not compatible with transformations like `jit` and `vmap`, which require all arrays to be statically shaped.

The solution would be to re-express your algorithm in a way that uses statically-shaped arrays. For example, if the number of nonzero entries is known statically, you can pass it to the `size` argument of `jnp.nonzero`. I'm not sure whether this is relevant in your case: I'm not f…
-
Yes, I completely agree. The problem here is that there is no way to define …