Prior to #10278, I would've written a nested for-loop in jax to accumulate all possible values of a function I understand that this API is being deprecated in favour of Also, why is the Thanks!
Replies: 1 comment 4 replies
Assuming your function `f` handles broadcasting semantics, you could do this:

```python
result = f(Xvals[:, None, None], Yvals[None, :, None], Zvals[None, None, :]).sum((0, 1, 2))
```

If not, you could do something similar with `vmap`:

```python
f_mapped = vmap(vmap(vmap(f, in_axes=(0, None, None)), in_axes=(None, 0, None)), in_axes=(None, None, 0))(Xvals, Yvals, Zvals).sum((0, 1, 2))
```

I suspect either of these would be much more performant than `fori_loop`.

But if you need `fori_loop`, you can express this by nesting three loop calls:

```python
import jax.numpy as jnp
from jax import lax, vmap

x = jnp.arange(5)
y = jnp.arange(10)
z = jnp.arange(15)

def f(x, y, z):
    return x * y * z

# Reference: plain Python triple loop over every (i, j, k) combination.
print(sum(f(x[i], y[j], z[k])
          for i in range(len(x))
          for j in range(len(y))
          for k in range(len(z))))
# 47250

# Broadcasting: each input varies along its own axis.
print(f(x[None, None, :], y[None, :, None], z[:, None, None]).sum())
# 47250

# Nested vmap: one vmap per input, for functions that do not broadcast.
print(vmap(vmap(vmap(f, in_axes=(0, None, None)), in_axes=(None, 0, None)), in_axes=(None, None, 0))(x, y, z).sum())
# 47250

# Three nested fori_loop calls, each accumulating into its own carry.
print(lax.fori_loop(0, len(x),
    lambda i, vx: vx + lax.fori_loop(0, len(y),
        lambda j, vy: vy + lax.fori_loop(0, len(z),
            lambda k, vz: vz + f(x[i], y[j], z[k]),
            0),
        0),
    0))
# 47250
```
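To make the "if not" case concrete, here is a minimal sketch of a function that cannot simply be called on broadcast arrays, wrapped with the same nested `vmap` pattern. The function `g` and the helper `outer_sum` are hypothetical names chosen for illustration and are not part of the original question. `g` branches with `lax.cond`, which expects a scalar predicate, so calling it directly on array inputs would be rejected, while `vmap` handles the batching.

```python
import jax.numpy as jnp
from jax import lax, vmap

def g(x, y, z):
    # Hypothetical scalar-only function: lax.cond expects a scalar predicate,
    # so g is written for one (x, y, z) triple at a time.
    return lax.cond(x > y,
                    lambda a, b, c: a * c,
                    lambda a, b, c: b * c,
                    x, y, z)

def outer_sum(fn, xs, ys, zs):
    # Hypothetical helper: map fn over every (x, y, z) combination, then sum.
    over_x = vmap(fn, in_axes=(0, None, None))
    over_xy = vmap(over_x, in_axes=(None, 0, None))
    over_xyz = vmap(over_xy, in_axes=(None, None, 0))
    return over_xyz(xs, ys, zs).sum()

xs = jnp.arange(3)
ys = jnp.arange(4)
zs = jnp.arange(5)

total = outer_sum(g, xs, ys, zs)

# Plain Python reference for the same accumulation: g is max(x, y) * z.
reference = sum(max(a, b) * c
                for a in range(3)
                for b in range(4)
                for c in range(5))

print(int(total), reference)
```

Building the three `vmap` layers one at a time keeps the `in_axes` bookkeeping readable: each layer batches exactly one argument and leaves the other two unbatched.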