lax.scan very slow on GPU #10233
-
I have an array of complex numbers:

```python
import jax.numpy as jnp
from jax import jit, vmap, lax, random
from jax.config import config

config.update('jax_platform_name', 'gpu')

key = random.PRNGKey(42)
key, subkey = random.split(key)
z = random.uniform(subkey, (10, 2000))

@jit
def linear_sum_assignment(a, b):
    idcs_d = jnp.argsort(jnp.abs(b - a[:, None]), axis=1)
    idcs_final = jnp.repeat(999, len(a))

    def f(carry, idcs_d_row):
        i, idcs_final = carry
        cond1 = jnp.isin(idcs_d_row[0], jnp.array(idcs_final))
        cond2 = jnp.isin(idcs_d_row[1], jnp.array(idcs_final))
        idx_closest = jnp.where(
            cond1, jnp.where(cond2, idcs_d_row[2], idcs_d_row[1]), idcs_d_row[0]
        )
        idcs_final = idcs_final.at[i].set(idx_closest)
        return (i + 1, idcs_final), idx_closest

    _, res = lax.scan(f, (0, idcs_final), idcs_d)
    return res

@jit
def f(z):
    def apply_linear_sum_assignment(carry, z_slice):
        idcs = linear_sum_assignment(carry, z_slice)
        return z_slice[idcs], z_slice[idcs]

    _, z_permuted = lax.scan(
        apply_linear_sum_assignment, z[:, 0], z.T
    )
    return z_permuted

# Slow on GPU
res = f(z)
```

This code is more than an order of magnitude slower on the GPU than on the CPU. Am I correct in thinking that this result is expected because the input to …
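When comparing CPU and GPU timings for jitted code, two pitfalls can exaggerate the gap: the first call includes compilation, and JAX dispatches asynchronously, so a plain `time.time()` around the call may measure almost nothing. A minimal timing sketch (using a stand-in jitted function `g`, not the `f` above, so the snippet runs on its own):

```python
import time
import jax.numpy as jnp
from jax import jit, random

@jit
def g(x):  # stand-in for the scan-based function above
    return jnp.sort(x, axis=1)

key = random.PRNGKey(0)
z = random.uniform(key, (10, 2000))

g(z).block_until_ready()  # first call compiles; exclude it from timing

t0 = time.perf_counter()
for _ in range(10):
    # block_until_ready forces completion of the async dispatch
    g(z).block_until_ready()
dt = (time.perf_counter() - t0) / 10
print(f"mean time per call: {dt * 1e3:.3f} ms")
```

If the gap survives this measurement, the slowdown is genuinely in the computation rather than in dispatch or compilation overhead.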
Replies: 3 comments 3 replies
-
Do you need an exact minimization, or is an approximation okay?
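If an exact solution is required, one option (not proposed in this thread, and only worthwhile when the host round-trip is cheaper than the scan) is to call SciPy's Hungarian solver from JAX via `jax.pure_callback`. A hedged sketch, assuming SciPy is available:

```python
import numpy as np
import scipy.optimize
import jax
import jax.numpy as jnp

def host_lsa(cost):
    # Hungarian algorithm on the host; returns the column index
    # assigned to each row of the cost matrix.
    _, col = scipy.optimize.linear_sum_assignment(np.asarray(cost))
    return col.astype(np.int32)

def lsa(cost):
    out_shape = jax.ShapeDtypeStruct((cost.shape[0],), jnp.int32)
    return jax.pure_callback(host_lsa, out_shape, cost)

cost = jnp.array([[4., 1., 3.],
                  [2., 0., 5.],
                  [3., 2., 2.]])
print(lsa(cost))
```

The callback is opaque to XLA, so it cannot be fused or differentiated, but it gives the exact optimum rather than the greedy approximation in the original `scan`.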
-
BTW, `scan` is slow on accelerators, especially when there is an in-place update to the carry (a strong data dependency). And you use two nested `scan`s, which cannot leverage the parallelism capability of the GPU at all.
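The carry dependency is what serializes `scan`: step *i* cannot start until step *i-1* finishes. When the combining operation is associative, `lax.associative_scan` can compute the same result in logarithmic depth, which GPUs can parallelize. A toy sketch with a cumulative sum (not the assignment problem above, whose greedy update is not associative):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1.0, 9.0)

# Sequential version: each step depends on the previous carry,
# so the GPU executes the 8 steps one after another.
def step(carry, xi):
    carry = carry + xi
    return carry, carry

_, seq = lax.scan(step, 0.0, x)

# Parallel version: an associative combine can be evaluated as a
# balanced tree, in O(log n) depth instead of O(n).
par = lax.associative_scan(jnp.add, x)

print(seq)  # [ 1.  3.  6. 10. 15. 21. 28. 36.]
print(par)  # same values
```

This only helps when the per-step update can be phrased as an associative operator; a data-dependent in-place update like the one in the question generally cannot.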