
Assuming your function f supports NumPy-style broadcasting, you can evaluate it on the whole grid at once and then sum:

result = f(Xvals[:, None, None], Yvals[None, :, None], Zvals[None, None, :]).sum((0, 1, 2))
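A small self-contained check of the broadcasting version. The inputs and the f below are made-up assumptions (f = x * y * z uses only elementwise operations, so it broadcasts). Indexing with None turns the three 1-D arrays into shapes (5, 1, 1), (1, 10, 1) and (1, 1, 15), so f is evaluated on the full (5, 10, 15) grid before the sum:

```python
import jax.numpy as jnp

# Hypothetical inputs, chosen only for illustration.
Xvals = jnp.arange(5.0)
Yvals = jnp.arange(10.0)
Zvals = jnp.arange(15.0)

def f(x, y, z):
    # Assumed f: only elementwise ops, so broadcasting applies.
    return x * y * z

# (5, 1, 1), (1, 10, 1) and (1, 1, 15) broadcast to a (5, 10, 15) grid.
grid = f(Xvals[:, None, None], Yvals[None, :, None], Zvals[None, None, :])
result = grid.sum((0, 1, 2))

# Brute-force reference in plain Python over the same integer values.
expected = sum(x * y * z for x in range(5) for y in range(10) for z in range(15))
print(grid.shape, float(result), expected)
```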

If it doesn't, you can build the same grid of values with nested vmap calls:

result = vmap(vmap(vmap(f, in_axes=(0, None, None)), in_axes=(None, 0, None)), in_axes=(None, None, 0))(Xvals, Yvals, Zvals).sum((0, 1, 2))
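A sketch of the nested-vmap version, assuming an f that only accepts scalars. The f_scalar below is a hypothetical stand-in: it stacks its arguments with jnp.stack, which would reject the mismatched broadcast shapes but works once vmap feeds it scalars. Because the outermost vmap maps over z, the output grid has shape (len(Zvals), len(Yvals), len(Xvals)), the reverse of the broadcasting layout; that doesn't matter for a full sum.

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical inputs, chosen only for illustration.
Xvals = jnp.arange(5.0)
Yvals = jnp.arange(10.0)
Zvals = jnp.arange(15.0)

def f_scalar(x, y, z):
    # Hypothetical f that expects scalars: jnp.stack needs equal shapes.
    return jnp.prod(jnp.stack([x, y, z]))

# Innermost vmap maps over x, the middle over y, the outermost over z.
f_mapped = vmap(
    vmap(vmap(f_scalar, in_axes=(0, None, None)), in_axes=(None, 0, None)),
    in_axes=(None, None, 0),
)
grid = f_mapped(Xvals, Yvals, Zvals)  # shape (15, 10, 5)
result = grid.sum()

# Cross-check against broadcasting (plain x * y * z does broadcast),
# transposed into the same (z, y, x) axis order.
reference = (Xvals[:, None, None] * Yvals[None, :, None] * Zvals[None, None, :]).transpose(2, 1, 0)
print(grid.shape, float(result), bool(jnp.allclose(grid, reference)))
```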

I suspect either of these would be much faster than fori_loop, since they evaluate f over the whole grid in vectorized form instead of one element at a time.

But if you need fori_loop, you can express this by nesting three loop calls:

import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(5)
y = jnp.arange(10)
z = jnp.arange(15)

def f(x, y, z):
  return x * y * z

print(sum(f(x[i], y[j], z[k])
          for i in range(len(x))
     …
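One possible way to nest three lax.fori_loop calls, sketched with the same toy x, y, z and f as above. This is not a reconstruction of the truncated code, and triple_sum is a hypothetical name. Each loop body closes over the outer loop indices and threads a single scalar accumulator through all three levels; the initial accumulator is created with x.dtype so the carry keeps the same type on every iteration, which fori_loop requires.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)
y = jnp.arange(10)
z = jnp.arange(15)

def f(x, y, z):
    return x * y * z

def triple_sum(x, y, z):
    # Hypothetical helper: sum f(x[i], y[j], z[k]) over all i, j, k.
    def body_i(i, acc_i):
        def body_j(j, acc_j):
            def body_k(k, acc_k):
                # Innermost step: add one grid point to the running total.
                return acc_k + f(x[i], y[j], z[k])
            return lax.fori_loop(0, z.shape[0], body_k, acc_j)
        return lax.fori_loop(0, y.shape[0], body_j, acc_i)
    # Start from a scalar zero with the inputs' dtype so the carry type is stable.
    return lax.fori_loop(0, x.shape[0], body_i, jnp.zeros((), dtype=x.dtype))

total = jax.jit(triple_sum)(x, y, z)
expected = sum(i * j * k for i in range(5) for j in range(10) for k in range(15))
print(int(total), expected)
```

Because the three loops run sequentially, this performs len(x) * len(y) * len(z) iterations one at a time, which is why the vectorized versions above are likely to be faster.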

Answer selected by gerdm