How to avoid loops to fill a 2D array in parallel #9191
jecampagne asked this question in Q&A
It's hard to give a definite answer without more knowledge of what …

import numpy as np
x_val = np.linspace(0, 100, 10)
y_val = np.linspace(0, 100, 20)
# broadcasting: shape (10, 1) * shape (1, 20) -> shape (10, 20)
result = x_val[:, None] * y_val[None, :]
print(result.shape)
# (10, 20)
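The same broadcasting approach is not limited to a product: any function built from elementwise NumPy operations can be evaluated on the whole (i, j) grid at once, and it can return several arrays. A minimal sketch, assuming the real per-element work can be written with NumPy ufuncs; `compute_pair` and its two formulas are hypothetical placeholders for result1 and result2 from the question:

```python
import numpy as np

# Hypothetical stand-in for the "fairly complicated" per-element computations;
# any expression built from NumPy ufuncs broadcasts the same way.
def compute_pair(x, y):
    result1 = np.sin(x) * np.cos(y)
    result2 = np.cos(x) + np.sin(y)
    return result1, result2

x_val = np.linspace(0, 100, 10)
y_val = np.linspace(0, 100, 20)

# x_val[:, None] has shape (10, 1) and y_val[None, :] has shape (1, 20),
# so compute_pair is evaluated on every (i, j) pair in one call.
result1, result2 = compute_pair(x_val[:, None], y_val[None, :])
print(result1.shape, result2.shape)  # (10, 20) (10, 20)

# Cross-check one entry against the scalar computation.
r1, r2 = compute_pair(x_val[3], y_val[7])
assert np.isclose(result1[3, 7], r1) and np.isclose(result2[3, 7], r2)
```

This only works when every operation inside the function is elementwise; code with Python-level branching on individual values would need rewriting (for example with np.where).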
I think that is exactly what you need.

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * jnp.cos(y)  # any complicated computation based on a pair of elements

# inner vmap maps over y with x fixed; outer vmap maps that over x with y fixed
y_vmap_f = jax.vmap(f, (None, 0))
xy_vmap_f = jax.vmap(y_vmap_f, (0, None))

nx, ny = 20, 10
xs = jnp.arange(nx)
ys = jnp.arange(ny)
out = xy_vmap_f(xs, ys)

# verification of correctness (allclose avoids spurious last-bit float mismatches)
assert out.shape == (nx, ny)
for i in range(nx):
    for j in range(ny):
        assert jnp.allclose(out[i, j], f(xs[i], ys[j]))
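Since the question needs two arrays (result1 and result2), it may help to know that jax.vmap also works on functions that return a tuple, and that the nested map can be compiled with jax.jit. A minimal sketch; `f_pair`, `f_grid` and the two formulas are hypothetical placeholders:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-pair computation returning two values,
# mirroring result1[i, j] and result2[i, j] from the question.
def f_pair(x, y):
    r1 = jnp.sin(x) * jnp.cos(y)
    r2 = jnp.cos(x) + jnp.sin(y)
    return r1, r2

# Inner vmap maps over y (x fixed); outer vmap maps over x (y fixed).
# vmap batches every output in the returned tuple, so both get shape (nx, ny).
f_grid = jax.jit(jax.vmap(jax.vmap(f_pair, in_axes=(None, 0)), in_axes=(0, None)))

nx, ny = 20, 10
xs = jnp.linspace(0.0, 1.0, nx)
ys = jnp.linspace(0.0, 2.0, ny)
result1, result2 = f_grid(xs, ys)
print(result1.shape, result2.shape)  # (20, 10) (20, 10)
```

The per-pair function stays written for scalars; vmap takes care of the batching, which is usually easier to maintain than manual broadcasting when the real computation is long.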
Hello,
here is a very simplified version of my use case:
This code looks horrible. I only use np.sin/np.cos here to keep things simple, but in my real code the computation of result1[i,j] and result2[i,j] is fairly complicated.
I would like to know how I can use jax.lax.scan or fori_loop for this use case, to take advantage of the powerful parallel computing in JAX/lax. Any help is welcome.
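For reference, the fill can be written with jax.lax.fori_loop or jax.lax.scan roughly as follows. This is a minimal sketch, assuming the per-element work is a pure function of one x and one y; `f`, `fill_fori` and `fill_scan` are hypothetical names. Note that both primitives express sequential loops rather than parallel ones:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-element computation standing in for the real code.
def f(x, y):
    return jnp.sin(x) * jnp.cos(y)

@jax.jit
def fill_fori(xs, ys):
    nx, ny = xs.shape[0], ys.shape[0]  # shapes are static under jit
    out0 = jnp.zeros((nx, ny), dtype=xs.dtype)

    def outer(i, out):
        def inner(j, out):
            # Functional update: returns a new array with out[i, j] replaced.
            return out.at[i, j].set(f(xs[i], ys[j]))
        return jax.lax.fori_loop(0, ny, inner, out)

    return jax.lax.fori_loop(0, nx, outer, out0)

@jax.jit
def fill_scan(xs, ys):
    # Scan over x: each step emits the row f(x, ys) (f broadcasts over ys),
    # and scan stacks the rows into an (nx, ny) array. No carry is needed.
    def step(carry, x):
        return carry, f(x, ys)
    _, out = jax.lax.scan(step, None, xs)
    return out

xs = jnp.linspace(0.0, 1.0, 20)
ys = jnp.linspace(0.0, 2.0, 10)
a = fill_fori(xs, ys)
b = fill_scan(xs, ys)
print(a.shape, b.shape)  # (20, 10) (20, 10)
```

Because fori_loop and scan run their iterations one after another, they are mainly useful when element (i, j) depends on previously computed elements. When each element is independent, as in this use case, jax.vmap or plain broadcasting is usually the better fit.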