vmap with reduction #8490
-
I have a situation in my code where I have a very large matrix. However, I am a bit stuck on the inverse problem: given a bunch of jobs that can modify the shared matrix `m`. I saw two other functions. EDIT:
import jax.numpy as jnp
from jax import grad, jit, vmap
import math
import numpy as np
import jax.nn as jnn
import timeit
import jax
# Simulation parameters shared by the functions below.
n_particles = 3000  # leading (batch) dimension of the benchmark inputs
n_grid = 128  # number of grid cells along each axis
dx = 1.0 / n_grid  # width of one grid cell
# Reciprocal of dx. Multiplying a position by it yields grid coordinates, so
# it must be n_grid (not 1/n_grid); otherwise every position maps to cell 0.
inv_dx = float(n_grid)
dim = 2  # number of components per position/velocity vector
mass = 1.0  # scalar factor applied to v in p2g_accum
def p2g_accum(x, v, affine):
    """Accumulate one particle's weighted contributions onto a dense grid.

    The function is written for a single particle; the batch axis is added
    by wrapping it in ``vmap``.

    Args:
        x: Particle position; its first two components select the grid cell
            (presumably shape ``(dim,)``).
        v: Per-particle vector that is scaled by ``mass`` (presumably a
            velocity of shape ``(dim,)``).
        affine: Matrix that the cell offset ``dpos`` is right-multiplied by
            (presumably shape ``(dim, dim)``).

    Returns:
        An array of shape ``(n_grid, n_grid, dim)`` holding the sum of the
        nine stencil contributions of this particle.
    """
    # Position in grid units, the base cell of the 3x3 stencil, and the
    # fractional offset of the particle from that base cell.
    grid_pos = x * inv_dx
    base = jnp.int32(grid_pos - 0.5)
    fx = grid_pos - base
    # Three per-axis weights; w[k][axis] is the weight of stencil offset k.
    w = [0.5 * (1.5 - fx) ** 2, 0.75 - (fx - 1) ** 2, 0.5 * (fx - 0.5) ** 2]
    momentum = mass * v

    acc = jnp.zeros((n_grid, n_grid, dim))
    stencil = [(i, j) for i in range(3) for j in range(3)]
    for i, j in stencil:
        offset = jnp.array([i, j])
        node = base + offset
        dpos = (offset - fx) * dx
        weight = w[i][0] * w[j][1]
        # A one-hot mask over the flattened grid selects the target cell
        # without dynamic indexing; indices outside the grid give all zeros.
        flat_index = node[0] * n_grid + node[1]
        mask = jnn.one_hot(flat_index, n_grid * n_grid).reshape(n_grid, n_grid, 1)
        contribution = weight * (momentum + dpos @ affine)
        # Broadcast (1, 1, dim) * (n_grid, n_grid, 1) -> (n_grid, n_grid, dim).
        acc = acc + contribution[None, None, :] * mask
    return acc
# Batched p2g_accum: maps over the leading axis of x, v and affine and
# returns one dense grid per batch entry.
p2_j = vmap(p2g_accum)


def p2g_reduce(x, v, affine):
    """Compute every particle's grid contribution and sum them into one grid.

    Args:
        x, v, affine: Batched inputs forwarded unchanged to ``p2_j``.

    Returns:
        The per-particle grids summed over the leading (batch) axis.
    """
    per_particle_grids = p2_j(x, v, affine)
    return per_particle_grids.sum(axis=0)


# Jitted reduction wrapped in jax.checkpoint.
p2r = jax.checkpoint(jit(p2g_reduce))
if __name__ == '__main__':
    # Random benchmark inputs with n_particles as the leading (batch) axis.
    x = jnp.array(np.random.rand(n_particles, dim))
    v = jnp.array(np.random.rand(n_particles, dim))
    affine = jnp.array(np.random.rand(n_particles, dim, dim))
    # Warm-up call so compilation is excluded from the measurement. Wait for
    # the result: JAX dispatches asynchronously, and unfinished warm-up work
    # would otherwise be charged to the first timed run.
    p2r(x, v, affine).block_until_ready()
    number = 10
    # Average wall-clock seconds per call over `number` runs.
    elapsed = timeit.timeit(
        lambda: p2r(x, v, affine).block_until_ready(), number=number
    )
    print(elapsed / number)
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 12 replies
-
Hi - thanks for the question. I'm a bit unclear on what kind of operation you're describing. Can you edit the question with a short example of the kind of operation you want to do? What do you mean by "a bunch of jobs that can modify shared matrix `m`"?
Beta Was this translation helpful? Give feedback.
Hi - thanks for the question. I'm a bit unclear on what kind of operation you're describing. Can you edit the question with a short example of the kind of operation you want to do? What do you mean by "a bunch of jobs that can modify shared matrix `m`"?