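Why not set `@partial(jax.jit, static_argnums=(1,))`?

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))  # k is static: each new k value triggers a recompile
def top_k_hot(arr, k):
    val, ind = jax.lax.top_k(arr, k)  # val.size == ind.size == k
    ret = jnp.zeros(arr.shape[0], jnp.int32)
    ret = ret.at[ind].set(1)  # jax.ops.index_update has since been removed; .at[].set() replaces it
    return ret
```

Been trying to do something similar but with 2D arrays and ...

```python
val, idx = jax.lax.top_k(prob_tensor, topk)  # top_k works along the last axis
tops = jax.nn.one_hot(idx, prob_tensor.shape[-1], dtype=jnp.int32)  # shape (..., topk, n)
return tops.sum(axis=-2)  # sum over the topk axis -> (..., n) multi-hot
```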
tl;dr: I want to write a `top_k_hot` function that is (1) jittable and (2) has minimal performance loss. It takes two inputs:

- `arr`: an array of size n
- `k`: an integer

and outputs an array of size n that is:

- 1 at the top-k indices of `arr`
- 0 at every index that is not in the top k of `arr`

Simple example:
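To make the intended input/output concrete, here is a minimal un-jitted reference sketch. The helper name `top_k_hot_reference` and the input values are illustrative, not from the post; since k is a plain Python int and nothing is traced, `jax.lax.top_k` works fine here.

```python
import jax
import jax.numpy as jnp

def top_k_hot_reference(arr, k):
    # Hypothetical helper: un-jitted reference for the intended behavior.
    # k is a concrete Python int here, so the size-k output of top_k is not a problem.
    _, ind = jax.lax.top_k(arr, k)
    return jnp.zeros(arr.shape[0], jnp.int32).at[ind].set(1)

arr = jnp.array([0.1, 0.9, 0.3, 0.7, 0.5])  # illustrative values
print(top_k_hot_reference(arr, 2))  # [0 1 0 1 0]
```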
A simple implementation would be to use `jax.lax.top_k`. However, this is not jit-able because k is variable, and `jax.lax.top_k` outputs an array of size k. I came up with some alternatives, but they still don't solve the problem that `jax.lax.top_k` outputs a variable-sized array. Arg-sorting the array and using a sliced index also doesn't work, because `jax.ops.index_update` can't take a `start:stop:step` slice with variable bounds and still be jit-able.
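One way around the slicing problem, sketched here under the assumption that a full sort is acceptable, is to keep the argsort but replace the variable slice with a mask over all n positions, so no array ever has size k. The name `top_k_hot_argsort` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def top_k_hot_argsort(arr, k):
    # Sketch: k may be a traced value, because no array of size k is ever built.
    n = arr.shape[0]
    order = jnp.argsort(-arr)  # indices from largest to smallest value, static shape (n,)
    # Instead of slicing order[:k], mark the first k positions of the sorted order with a mask.
    hot = (jnp.arange(n) < k).astype(jnp.int32)
    # Scatter the mask back to the original positions; order is a permutation, so indices never collide.
    return jnp.zeros(n, jnp.int32).at[order].set(hot)

arr = jnp.array([0.1, 0.9, 0.3, 0.7, 0.5])  # illustrative values
print(top_k_hot_argsort(arr, 3))  # [0 1 0 1 1]
```

Because k is passed as a traced argument rather than a static one, calling this with different k values should reuse the same compiled function; the price is a full O(n log n) sort on every call.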
Sorting the array and using the k-th largest value as a threshold works, but it doesn't behave exactly like `top_k`, and I'm quite worried about the performance.
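A minimal sketch of the threshold variant, assuming 1 <= k <= n (the name `top_k_hot_threshold` is hypothetical). It also shows one concrete way this approach can differ from `top_k`: every value tied with the threshold gets marked, so more than k positions can come out as 1.

```python
import jax
import jax.numpy as jnp

@jax.jit
def top_k_hot_threshold(arr, k):
    # Sketch of "sort, then take the k-th largest value"; assumes 1 <= k <= n.
    kth_largest = jnp.sort(arr)[::-1][k - 1]  # indexing with a traced scalar is allowed under jit
    # Every element >= the threshold is marked, so ties at the boundary
    # can produce more than k ones (unlike jax.lax.top_k).
    return (arr >= kth_largest).astype(jnp.int32)

tied = jnp.array([3.0, 1.0, 3.0, 2.0])  # illustrative values with a tie
print(top_k_hot_threshold(tied, 1))  # [1 0 1 0]: two ones even though k == 1
```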
Partial sorting could improve performance a little, but this function will run very frequently, so I'd like the optimal solution if possible.
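If an upper bound on k is known ahead of time, one option (a suggestion, not something from the thread) is to call `jax.lax.top_k` with that static bound and then mask down to the traced k. `top_k_hot_bounded` and `max_k` are hypothetical names, and this assumes 0 <= k <= max_k.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(2,))
def top_k_hot_bounded(arr, k, max_k):
    # Hypothetical variant: max_k is a static upper bound; k is traced and must satisfy 0 <= k <= max_k.
    # Depending on the backend, top_k with a small static size can be cheaper than a full sort.
    _, ind = jax.lax.top_k(arr, max_k)  # candidate indices, ordered from largest to smallest value
    keep = (jnp.arange(max_k) < k).astype(jnp.int32)  # 1 for the first k of the max_k candidates
    # ind holds distinct positions, so the scatter has no collisions.
    return jnp.zeros(arr.shape[0], jnp.int32).at[ind].set(keep)

arr = jnp.array([0.1, 0.9, 0.3, 0.7, 0.5])  # illustrative values
print(top_k_hot_bounded(arr, 2, 3))  # [0 1 0 1 0]
```

Only a new `max_k` value triggers recompilation. Whether this beats a full sort depends on n, `max_k` and the backend, so it is worth benchmarking against the argsort version.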
So it seems one can implement a jit-able `top_k_hot` function, but for now I can only do it with an unfortunate performance loss.
Any help will be appreciated!!
Also, thanks for making this awesome library.
It feels like I've finally found the right tool to use.