-
Let's say I've got a ShapedArray holding numbers, all of which are greater than zero. I'd like to replace any duplicates in that array with zeros. For example:
Does anyone know of a fast way to do this on the GPU? Or is it impossible to do entirely on the GPU since you have to check the values?
-
With a hash table, this would be an O[N] algorithm. Unfortunately, JAX does not have any hash table data structure (a deficiency it inherits from XLA, due to the fact that accelerators are not optimized for such operations), so you'll probably be limited to an O[N²] solution, in which you compare each entry to every previous entry in a vectorized fashion. Here's one approach along those lines:

    import jax.numpy as jnp

    def set_duplicates_to(x, val):
        x = jnp.asarray(x)
        assert x.ndim == 1
        N, = x.shape
        # set to x-1 so columns do not match x
        M = jnp.zeros((N - 1, N), dtype=x.dtype).at[:].set(x - 1)
        col = jnp.arange(N)
        row = jnp.arange(1, N)[:, None]
        # row k-1 of M holds x shifted right by k places, for k = 1..N-1.
        # note: in JAX, overflowing indices are dropped.
        M = M.at[row - 1, row + col].set(x)
        # compare each entry of x to every previous entry.
        return x.at[(x == M).any(0)].set(val)

    x = jnp.array([1, 2, 3, 1, 2, 4])
    set_duplicates_to(x, 0)
    # DeviceArray([1, 2, 3, 0, 0, 4], dtype=int32)

This will not work correctly if your array is floating point and has any NaN or inf values (it requires x != x-1 and x == x to hold), but you could make it work by first converting the inputs bitwise to integers of the appropriate width.
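As a rough illustration of the bitwise-conversion idea, here is a minimal sketch that is still O[N²]. It swaps the shifted-matrix construction for a plain pairwise comparison restricted to earlier positions, and it uses `jnp.where` rather than boolean-mask indexing so that it can run under `jax.jit`. The names `_as_int_bits` and `zero_out_repeats` are hypothetical helpers made up for this example, not part of JAX. One assumption to be aware of: after the bitcast, two values count as duplicates only if their bit patterns are identical, so NaNs with identical bits are treated as repeats, while `0.0` and `-0.0` are treated as different.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _as_int_bits(x):
    # Hypothetical helper: reinterpret floating-point values as integers of
    # the same width, so that equality means "identical bit pattern".
    if jnp.issubdtype(x.dtype, jnp.floating):
        int_dtype = {2: jnp.int16, 4: jnp.int32, 8: jnp.int64}[x.dtype.itemsize]
        return lax.bitcast_convert_type(x, int_dtype)
    return x


@jax.jit
def zero_out_repeats(x):
    # Hypothetical name. Expects a 1-D array.
    keys = _as_int_bits(x)
    n = keys.shape[0]
    # same[i, j] is True when entries i and j hold the same value.
    same = keys[:, None] == keys[None, :]
    # earlier[i, j] is True when j < i (strictly lower triangle).
    earlier = jnp.tril(jnp.ones((n, n), dtype=bool), k=-1)
    # Entry i is a repeat if it matches any earlier entry.
    is_repeat = (same & earlier).any(axis=1)
    return jnp.where(is_repeat, jnp.zeros_like(x), x)


x = jnp.array([1, 2, 3, 1, 2, 4])
print(zero_out_repeats(x))  # values: [1 2 3 0 0 4]

y = jnp.array([1.0, jnp.inf, jnp.nan, jnp.inf, jnp.nan, 1.0], dtype=jnp.float32)
print(zero_out_repeats(y))  # the second inf, second nan, and second 1.0 become 0
```

Like the approach above, this materializes an N×N intermediate, so memory use grows quadratically with the array length.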