I think the best way to improve this would be to remove the fori_loop altogether: fori_loop is inherently sequential, and no step of your computation depends on the result of the previous one. Expressing the whole binning as a single vectorized scatter-add lets the compiler process all points at once instead of one per loop iteration, which should lead to much better performance:

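import jax.numpy as jnp
import numpy as np

# Pixel index of every point along x and y, computed for all points at once.
pts_x_id = jnp.digitize(pts[0,:],binsx).astype(np.uint32)
pts_y_id = jnp.digitize(pts[1,:],binsy).astype(np.uint32)
pts_z = pts[2,:]

# Scatter-add: each (y, x) pair adds its value, and repeated pairs accumulate.
img_sum_2 = jnp.zeros((Npix,Npix),dtype=jnp.float32).at[pts_y_id, pts_x_id].add(pts_z)
img_counts_2 = jnp.zeros((Npix,Npix),dtype=jnp.uint32).at[pts_y_id, pts_x_id].add(1)

# Check against the results of the original fori_loop version.
np.testing.assert_allclose(img_sum, img_sum_2)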
np.testing.assert_array_…
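To see that nothing here relies on sequential order, the sketch below builds the same two images both ways and compares them. It is a minimal, self-contained sketch under stated assumptions: bin_fori_loop is a hypothetical reconstruction of a loop-based version (the loop from the question is not reproduced on this page), bin_vectorized wraps the code above, the data are synthetic, and binsx/binsy are assumed to hold the Npix - 1 interior bin edges so that jnp.digitize returns pixel indices 0..Npix-1.

```python
import jax
import jax.numpy as jnp
import numpy as np


def bin_fori_loop(pts, binsx, binsy, npix):
    """Hypothetical sequential version: one point per fori_loop iteration."""
    pts = jnp.asarray(pts)
    x_id = jnp.digitize(pts[0, :], binsx)
    y_id = jnp.digitize(pts[1, :], binsy)

    def body(i, carry):
        img_sum, img_counts = carry
        # The images are threaded through the carry, so iteration i cannot
        # start until iteration i - 1 has finished.
        img_sum = img_sum.at[y_id[i], x_id[i]].add(pts[2, i])
        img_counts = img_counts.at[y_id[i], x_id[i]].add(1)
        return img_sum, img_counts

    init = (jnp.zeros((npix, npix), dtype=jnp.float32),
            jnp.zeros((npix, npix), dtype=jnp.uint32))
    return jax.lax.fori_loop(0, pts.shape[1], body, init)


def bin_vectorized(pts, binsx, binsy, npix):
    """Vectorized version: a single scatter-add over all points."""
    pts = jnp.asarray(pts)
    x_id = jnp.digitize(pts[0, :], binsx)
    y_id = jnp.digitize(pts[1, :], binsy)
    # Repeated (y, x) pairs accumulate rather than overwrite each other.
    img_sum = jnp.zeros((npix, npix), dtype=jnp.float32).at[y_id, x_id].add(pts[2, :])
    img_counts = jnp.zeros((npix, npix), dtype=jnp.uint32).at[y_id, x_id].add(1)
    return img_sum, img_counts


# Synthetic data (assumption): 1000 points in the unit square, random z values.
npix = 8
rng = np.random.default_rng(0)
pts = rng.random((3, 1000), dtype=np.float32)
# npix - 1 interior edges, so digitize maps every point in [0, 1) to 0..npix-1.
edges = jnp.asarray(np.linspace(0.0, 1.0, npix + 1, dtype=np.float32)[1:-1])

# npix fixes array shapes, so it must be a static (compile-time) argument.
loop_fn = jax.jit(bin_fori_loop, static_argnums=3)
vec_fn = jax.jit(bin_vectorized, static_argnums=3)
sum_loop, counts_loop = loop_fn(pts, edges, edges, npix)
sum_vec, counts_vec = vec_fn(pts, edges, edges, npix)

# Float sums may differ in the last bits because the addition order can differ.
np.testing.assert_allclose(np.asarray(sum_loop), np.asarray(sum_vec), rtol=1e-5)
np.testing.assert_array_equal(np.asarray(counts_loop), np.asarray(counts_vec))
```

Two details matter when adapting this. First, .at[...].add applies every update even when indices repeat, which is what makes one scatter equivalent to the loop. Second, JAX drops scatter updates at out-of-bounds indices by default instead of raising an error, so if binsx holds all Npix + 1 edges (making jnp.digitize return 1..Npix for in-range points), the indices need shifting by one, otherwise points land in the wrong pixel or are silently lost. When timing the two versions, call .block_until_ready() on the outputs, since JAX dispatches work asynchronously.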

Answer selected by jecampagne