-
Say I have an array of inputs … Something like this:
…
I think the direct equivalent in JAX would be a … How would you recommend implementing this process in JAX?
-
Wouldn't …
-
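I would suggest using the most straightforward approach:

```python
import jax.numpy as jnp
from jax import jit

@jit
def argmin_and_min(x):
    y = f(x)           # f: the function being minimized, applied to the whole input array
    i = jnp.argmin(y)  # index of the smallest output
    return i, y[i]     # that index and the corresponding minimum value

lowest_i, lowest_y = argmin_and_min(x)
```

You ask …

Above you mentioned being concerned about the need to allocate the full array …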
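For reference, here is a self-contained, runnable version of the `argmin_and_min` approach. The function `f` and the input `x` below are made up purely for illustration; substitute your own. Note that this version evaluates `f` on every input at once, so it materializes all `len(x)` outputs before reducing them.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function; substitute your own elementwise f.
def f(x):
    return (x - 2.0) ** 2

@jax.jit
def argmin_and_min(x):
    y = f(x)           # evaluates f on every input (allocates len(x) outputs)
    i = jnp.argmin(y)  # index of the smallest output
    return i, y[i]

x = jnp.arange(10, dtype=jnp.float32) * 0.5  # 0.0, 0.5, ..., 4.5
lowest_i, lowest_y = argmin_and_min(x)
print(int(lowest_i), float(lowest_y))  # 4 0.0
```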
-
Searching around for a "mapreduce" primitive in JAX led me to this thread. It looks like the answer is in fact to use `vmap` with the reduction applied, though apparently `scan` can work too under some circumstances. In any case, `jit` is crucial for optimizing the XLA computations.
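One way to combine `vmap`, a reduction, and `scan` is a chunked running argmin. The sketch below assumes that `f` maps a single scalar input to a scalar and that `len(x)` is a multiple of `chunk_size`; `chunked_argmin`, `chunk_size`, and the example `f` are hypothetical names, not part of the JAX API. Within each chunk, `vmap` evaluates `f` and `argmin` reduces the result. `lax.scan` then carries the best index and value found so far across chunks, so only `chunk_size` outputs of `f` exist at any one time instead of the full array.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function acting on a single scalar input.
def f(x):
    return (x - 2.0) ** 2

def chunked_argmin(f, x, chunk_size):
    """Return (index, value) of the smallest f(x[i]), evaluating f one chunk
    at a time. Assumes x.shape[0] is a multiple of chunk_size."""
    n_chunks = x.shape[0] // chunk_size
    chunks = x.reshape(n_chunks, chunk_size)

    def step(carry, inputs):
        best_i, best_y = carry
        chunk_idx, chunk = inputs
        y = jax.vmap(f)(chunk)               # f applied to this chunk only
        j = jnp.argmin(y)                    # best position inside the chunk
        cand_i = chunk_idx * chunk_size + j  # convert to a global index
        better = y[j] < best_y               # strict <: earliest index wins ties
        new_carry = (jnp.where(better, cand_i, best_i),
                     jnp.where(better, y[j], best_y))
        return new_carry, None               # nothing is stacked per step

    init = (jnp.array(0, dtype=jnp.int32), jnp.array(jnp.inf, dtype=x.dtype))
    (best_i, best_y), _ = jax.lax.scan(step, init, (jnp.arange(n_chunks), chunks))
    return best_i, best_y

x = jnp.arange(12, dtype=jnp.float32) * 0.5  # 0.0, 0.5, ..., 5.5
# f and chunk_size must be static under jit (they determine shapes and tracing).
best_i, best_y = jax.jit(chunked_argmin, static_argnums=(0, 2))(f, x, 4)
print(int(best_i), float(best_y))  # 4 0.0
```

The `chunk_size` parameter trades memory for parallelism: `chunk_size = len(x)` reduces to the plain `jit` + `argmin` approach, while `chunk_size = 1` is a fully sequential `scan`. The int32 index carry assumes JAX's default 32-bit mode (x64 disabled).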