
You might use lax.switch or jnp.choose for g (check out all the different ways to branch in JAX here), and then jnp.sum along the first dimension at the very end. The performance gain is probably not large for your example unless you are running on an accelerator.
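A minimal sketch of how the two options behave on a single index. The three functions in `branches` are hypothetical stand-ins for your `g_i`. One practical difference: `lax.switch` clamps an out-of-range index into `[0, len(branches) - 1]`, while `jnp.choose(..., mode='wrap')` evaluates every branch first and then picks one using the index modulo the number of branches.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical branch functions standing in for g_i
branches = [
    lambda x: x + 1.0,
    lambda x: x * 2.0,
    lambda x: -x,
]

x = jnp.array([1.0, 2.0, 3.0])

def via_switch(i, x):
    # lax.switch clamps i into range; with an unbatched index only the
    # chosen branch executes (under vmap, all branches are computed and selected)
    return lax.switch(i, branches, x)

def via_choose(i, x):
    # Every branch is evaluated, then jnp.choose picks result i (i % 3 with 'wrap')
    return jnp.choose(i, [f(x) for f in branches], mode='wrap')

print(via_switch(3, x))  # index 3 is clamped to 2, so this is -x
print(via_choose(3, x))  # index 3 wraps to 0, so this is x + 1
```

For valid indices the two agree; they only diverge when an index falls outside the list of branches.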

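```python
import jax.numpy as jnp
from jax import jit, lax, vmap

# g_i is your list of branch functions, each taking (x, y, z, u)
def g(i, x, y, z, u):
  # lax.switch picks g_i[i]; out-of-range indices are clamped
  return lax.switch(i, g_i, x, y, z, u)
  # or, evaluating every branch and then selecting one:
  # return jnp.choose(i, [func(x, y, z, u) for func in g_i], mode='wrap')

@jit
def total(i, x, y, z, u):
  # map g over the first axis of i only; x, y, z, u are shared (None)
  res = vmap(g, in_axes=(0, None, None, None, None))(i, x, y, z, u)
  # sum the per-index results along the first dimension
  return jnp.sum(res, axis=0)

total(i_ar, x_ar, y_ar, z_ar, u_ar)
```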
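For reference, a self-contained version you can run end to end. The branch functions in `g_i` and the input arrays are made up purely for illustration (they are not from the question), and `make_total` is a hypothetical helper so the `lax.switch` and `jnp.choose` variants can share the same vmap-and-sum wrapper. The loop at the end checks the result against summing `g` over the indices one at a time.

```python
import jax.numpy as jnp
import numpy as np
from jax import jit, lax, vmap

# Hypothetical branch functions standing in for the question's g_i;
# each maps the shared inputs (x, y, z, u) to an array of the same shape.
g_i = [
    lambda x, y, z, u: x * y,
    lambda x, y, z, u: z + u,
    lambda x, y, z, u: x - u,
]

def g_switch(i, x, y, z, u):
    return lax.switch(i, g_i, x, y, z, u)

def g_choose(i, x, y, z, u):
    # Evaluates every branch, then selects result i (wrapping out-of-range i)
    return jnp.choose(i, [f(x, y, z, u) for f in g_i], mode='wrap')

def make_total(branch_fn):
    @jit
    def total(i, x, y, z, u):
        # vmap over the index array only; x, y, z, u are shared by every index
        res = vmap(branch_fn, in_axes=(0, None, None, None, None))(i, x, y, z, u)
        # sum the per-index results along the first (mapped) dimension
        return jnp.sum(res, axis=0)
    return total

# Hypothetical inputs: i_ar holds branch indices, the rest are data arrays
i_ar = jnp.array([0, 2, 1, 0])
x_ar = jnp.arange(4.0)
y_ar = jnp.full(4, 2.0)
z_ar = jnp.ones(4)
u_ar = jnp.linspace(0.0, 1.0, 4)

result = make_total(g_switch)(i_ar, x_ar, y_ar, z_ar, u_ar)
result_choose = make_total(g_choose)(i_ar, x_ar, y_ar, z_ar, u_ar)

# Reference: plain Python loop over the indices
expected = sum(g_i[int(k)](x_ar, y_ar, z_ar, u_ar) for k in i_ar)
print(result)  # [ 1.  6. 11. 16.]
```

Because `in_axes` marks only the index array as mapped, each `g` call sees the full `x`, `y`, `z`, `u` arrays, and the final `jnp.sum` collapses the stacked per-index results.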

> this difference in nature between i_ar and x_ar, y_ar, z_ar, and u_ar

IIUC, this is actually a different question, and is exactly wh…

@stefanvuckovic1

Answer selected by stefanvuckovic1
Category
Q&A