How to efficiently do function evaluation conditioned on an array of integers #14428
I want to efficiently perform conditional function evaluation based on an array of integers, with other arrays of real numbers serving as inputs to those functions. I hope to find a JAX-based solution that provides significant performance improvements over the for-loop approach I describe below:
The role of "i_ar" is to act as an index that selects one of the four functions in the list "g_i": "i_ar" is an array of integers, each of which is an index into "g_i". By contrast, "x_ar", "y_ar", "z_ar", and "u_ar" are arrays of real numbers that are passed as inputs to the functions selected by "i_ar". I suspect this difference in nature between "i_ar" and the real-valued arrays is what makes it hard to find a more efficient JAX replacement for the for loop above. Any ideas on how to use JAX (or something else) to replace the for loop and compute "total" more efficiently? I have tried naively using vmap/pmap, but failed.
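For reference, here is a minimal sketch of the kind of loop being described. It assumes (hypothetically) that each of the four functions in `g_i` takes the whole arrays `x_ar`, `y_ar`, `z_ar`, `u_ar` elementwise and returns an array of the same shape, and that `total` is the sum of the selected functions' outputs. The branch bodies and the name `total_loop` are made up for illustration only.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the four functions in g_i; each works elementwise.
g_i = [
    lambda x, y, z, u: x + y * z - u,
    lambda x, y, z, u: jnp.sin(x) * u,
    lambda x, y, z, u: x * y * z * u,
    lambda x, y, z, u: jnp.exp(-z) + y,
]

def total_loop(i_ar, x_ar, y_ar, z_ar, u_ar):
    # Python-level loop: one round of op dispatch per index,
    # which tends to be slow when i_ar is long.
    total = jnp.zeros_like(x_ar)
    for i in i_ar.tolist():
        total = total + g_i[i](x_ar, y_ar, z_ar, u_ar)
    return total

x_ar = jnp.linspace(0.0, 1.0, 5)
y_ar = x_ar + 1.0
z_ar = 2.0 * x_ar
u_ar = 1.0 - x_ar
i_ar = jnp.array([0, 3, 1, 1, 2, 0])

result = total_loop(i_ar, x_ar, y_ar, z_ar, u_ar)
print(result)
```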
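You might use `lax.switch` or `jnp.choose` for `g` (check out all the different ways to branch in JAX here), and `jnp.sum` along the first dimension at the very end. The performance gain is probably not that large for your example, except when running on an accelerator.

```python
import jax.numpy as jnp
from jax import jit, lax, vmap

# g_i is the list of branch functions from the question.
def g(i, x, y, z, u):
    # Pick branch i and apply it to the shared inputs.
    return lax.switch(i, g_i, x, y, z, u)
    # or, evaluating every branch and then selecting:
    # return jnp.choose(i, [func(x, y, z, u) for func in g_i], mode='wrap')

@jit
def total(i, x, y, z, u):
    # Batch over i only; x, y, z, u are shared by every index.
    res = vmap(g, in_axes=(0, None, None, None, None))(i, x, y, z, u)
    return jnp.sum(res, axis=0)

total(i_ar, x_ar, y_ar, z_ar, u_ar)
```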
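To check the approach above end to end, here is a self-contained sketch with four made-up elementwise branches (hypothetical stand-ins for the question's `g_i`). It runs the `vmap` + `lax.switch` version and, for comparison, a different technique that is not part of the answer: because every term of the sum receives the same `x`, `y`, `z`, `u`, the total only depends on how often each index occurs, so one can count indices with `jnp.bincount` and evaluate each branch once. That shortcut relies on the inputs being shared across indices; it does not apply if they vary per index. Note also that when the index is batched under `vmap`, `lax.switch` is converted into a select, so every branch is evaluated for every element of `i_ar`; the counting version avoids that. The names `total_switch` and `total_counts` are hypothetical.

```python
import jax.numpy as jnp
from jax import jit, lax, vmap

# Hypothetical branch functions standing in for the question's g_i.
g_i = [
    lambda x, y, z, u: x + y * z - u,
    lambda x, y, z, u: jnp.sin(x) * u,
    lambda x, y, z, u: x * y * z * u,
    lambda x, y, z, u: jnp.exp(-z) + y,
]

def g(i, x, y, z, u):
    return lax.switch(i, g_i, x, y, z, u)

@jit
def total_switch(i, x, y, z, u):
    # vmap over i only; with a batched index the switch evaluates all branches.
    res = vmap(g, in_axes=(0, None, None, None, None))(i, x, y, z, u)
    return jnp.sum(res, axis=0)

@jit
def total_counts(i, x, y, z, u):
    # How many times each branch index occurs (static length so it works under jit).
    counts = jnp.bincount(i, length=len(g_i))
    # Evaluate each branch once: shape (len(g_i), N).
    outs = jnp.stack([func(x, y, z, u) for func in g_i])
    # Weighted sum of branch outputs by their counts.
    return jnp.tensordot(counts.astype(outs.dtype), outs, axes=1)

x_ar = jnp.linspace(0.0, 1.0, 5)
y_ar = x_ar + 1.0
z_ar = 2.0 * x_ar
u_ar = 1.0 - x_ar
i_ar = jnp.array([0, 3, 1, 1, 2, 0])

res_switch = total_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
res_counts = total_counts(i_ar, x_ar, y_ar, z_ar, u_ar)
print(res_switch, res_counts)
```

The counting variant trades one evaluation per index for one evaluation per branch, so it should pay off mainly when `i_ar` is much longer than `g_i`.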
IIUC, this is actually a different question, and is exactly where …

Example code here:

```python
import jax.numpy as jnp
from jax import lax, vmap

def f(x, y, z, u):
    # let's pretend it's some complex pointwise function where broadcasting doesn't work
    assert all(a.shape == () for a in [x, y, z, u])
    return jnp.array(0)

x_ar = y_ar = z_ar = jnp.zeros(10)
u_ar = jnp.zeros(7)

# nested vmaps save the day!
# batched_f deals with the said 'difference in nature' between xyz and u:
# the inner vmap maps over x, y, z together, the outer vmap maps over u
batched_f = vmap(vmap(f, in_axes=(0, 0, 0, None)), in_axes=(None, None, None, 0))
assert batched_f(x_ar, y_ar, z_ar, u_ar).shape == (len(u_ar), len(x_ar))

# mock 3 branches; defined at module level so len(g_i) works below
g_i = [batched_f] * 3

# g works with arrays (x, y, z, u) and a single i
def g(i, x, y, z, u):
    return lax.switch(i, g_i, x, y, z, u)

# so we vmap again, this time over i only
batched_g = vmap(g, in_axes=(0, None, None, None, None))
i_ar = jnp.arange(11) % len(g_i)
assert batched_g(i_ar, x_ar, y_ar, z_ar, u_ar).shape == (len(i_ar), len(u_ar), len(x_ar))
```
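To see what the nested `vmap` actually computes, here is a small sketch with a concrete pointwise `f` (made up for illustration: `x * y + z * u`) for which broadcasting happens to work, so the result can be checked against an explicit broadcast. Entry `[k, j]` of the output is `f(x[j], y[j], z[j], u[k])`: the outer `vmap` over `u` produces the leading axis.

```python
import jax.numpy as jnp
from jax import vmap

def f(x, y, z, u):
    # Hypothetical scalar function; every argument is a scalar here.
    return x * y + z * u

x_ar = jnp.arange(4.0)
y_ar = x_ar + 1.0
z_ar = 2.0 * x_ar
u_ar = jnp.array([0.5, -1.0, 3.0])

# Inner vmap: map x, y, z together with u fixed. Outer vmap: map u with x, y, z fixed.
batched_f = vmap(vmap(f, in_axes=(0, 0, 0, None)), in_axes=(None, None, None, 0))
out = batched_f(x_ar, y_ar, z_ar, u_ar)  # shape (len(u_ar), len(x_ar))

# Same result via explicit broadcasting: u varies along rows, x/y/z along columns.
expected = (x_ar * y_ar)[None, :] + z_ar[None, :] * u_ar[:, None]
print(out.shape)
```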