Faster fori_loop? #21512
Hello,
Here is a working snippet (I'm running on a V100 32G) that is intended to produce a 2D image of Npix x Npix pixels from 3D points: the X,Y values are digitized according to some bin arrays, and the Z values are collected so that, at the end of the day (but faster!), we get the mean Z value of all the points belonging to the same image pixel. My problem is that I need to do this for roughly 10^10 points, which are stored in 15 batches (files) of about 6.7 x 10^8 points each. So if anyone finds a better implementation, you are welcome. (N.B. I can also run on 4 GPUs, if a parallelized version can update img_sum & img_counts concurrently.) Thanks
%pylab inline
import numpy as np
import jax
import jax.numpy as jnp
# create a bunch of 3D points
pts = jax.random.uniform(jax.random.PRNGKey(42),(3,1_000_000),dtype=jnp.float32)
# Image size & binning
Ximgmin,Ximgmax=0.,1.
Yimgmin,Yimgmax=0.,1.
Npix = 100
binsx = jnp.linspace(Ximgmin,Ximgmax,Npix-1,dtype=jnp.float32)
binsy = jnp.linspace(Yimgmin,Yimgmax,Npix-1,dtype=jnp.float32)
# the function to optimize
def digitpts_and_updateimg_v1(pts,img_sum,img_counts):
    pts_x_id = jnp.digitize(pts[0,:],binsx).astype(np.uint32)
    pts_y_id = jnp.digitize(pts[1,:],binsy).astype(np.uint32)
    pts_z = pts[2,:]
    jax.debug.print("digit done")
    def body(ipt,carry):
        img_sum, img_counts=carry
        i = pts_x_id[ipt]
        j = pts_y_id[ipt]
        z = pts_z[ipt]
        return img_sum.at[j,i].add(z), img_counts.at[j,i].add(1)
    Npts = pts.shape[1]
    img_sum,img_counts = jax.lax.fori_loop(0,Npts,body,(img_sum,img_counts))
    return img_sum,img_counts
# init image sum/count
img_sum = jnp.zeros((Npix,Npix),dtype=jnp.float32)
img_counts = jnp.zeros((Npix,Npix),dtype=jnp.uint32)
# Go....
img_sum,img_counts= digitpts_and_updateimg_v1(pts,img_sum,img_counts)
# Plot
imshow(img_sum/img_counts);colorbar();
Replies: 1 comment 1 reply
I think the best way to improve this would be to remove the fori_loop altogether: fori_loop is inherently sequential, and nothing in your computation requires sequential dependence. Using vectorized operations will allow the compiler to execute steps in a vectorized manner, and should lead to much better performance:
pts_x_id = jnp.digitize(pts[0,:],binsx).astype(np.uint32)
pts_y_id = jnp.digitize(pts[1,:],binsy).astype(np.uint32)
pts_z = pts[2,:]
img_sum_2 = jnp.zeros((Npix,Npix),dtype=jnp.float32).at[pts_y_id, pts_x_id].add(pts_z)
img_counts_2 = jnp.zeros((Npix,Npix),dtype=jnp.uint32).at[pts_y_id, pts_x_id].add(1)
np.testing.assert_allclose(img_sum, img_sum_2)
np.testing.assert_array_equal(img_counts, img_counts_2)
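Since the question mentions 15 batches, here is a minimal sketch of how the vectorized update could be wrapped in a jitted function and called once per batch, with the running img_sum and img_counts carried between calls. The function name update_img, the small random batches standing in for the files, and the NaN handling for empty pixels are assumptions made for illustration; they are not part of the original snippet.

```python
import jax
import jax.numpy as jnp

# Same binning as in the question.
Npix = 100
binsx = jnp.linspace(0., 1., Npix - 1, dtype=jnp.float32)
binsy = jnp.linspace(0., 1., Npix - 1, dtype=jnp.float32)

@jax.jit
def update_img(pts, img_sum, img_counts):
    # Hypothetical helper: digitize X and Y into pixel indices.
    pts_x_id = jnp.digitize(pts[0, :], binsx)
    pts_y_id = jnp.digitize(pts[1, :], binsy)
    # One scatter-add per image replaces the whole fori_loop.
    img_sum = img_sum.at[pts_y_id, pts_x_id].add(pts[2, :])
    img_counts = img_counts.at[pts_y_id, pts_x_id].add(1)
    return img_sum, img_counts

img_sum = jnp.zeros((Npix, Npix), dtype=jnp.float32)
img_counts = jnp.zeros((Npix, Npix), dtype=jnp.uint32)

# Assumption: small random batches stand in for reading the real files.
n_batches, batch_size = 3, 100_000
for key in jax.random.split(jax.random.PRNGKey(42), n_batches):
    pts = jax.random.uniform(key, (3, batch_size), dtype=jnp.float32)
    img_sum, img_counts = update_img(pts, img_sum, img_counts)

# Mean Z per pixel; pixels that received no points are set to NaN.
img_mean = jnp.where(img_counts > 0,
                     img_sum / jnp.maximum(img_counts, 1),
                     jnp.nan)
```

As long as every batch has the same shape, update_img should be compiled only once and reused. With around 10^6 points per pixel in the full dataset, float32 accumulation may lose some precision, so it could be worth checking the result against a higher-precision sum on a subset.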