With a hash table, this would be an O[N] algorithm. Unfortunately, JAX does not have a hash table data structure (a limitation it inherits from XLA, since accelerators are not optimized for such operations), so you'll probably be limited to an O[N²] solution in which you compare each entry to every previous entry in a vectorized fashion.

Here's one approach along those lines:

import jax.numpy as jnp

def set_duplicates_to(x, val):
  x = jnp.asarray(x)
  assert x.ndim == 1  # check the rank before unpacking the shape
  N, = x.shape

  # set to x-1 so columns do not match x
  M = jnp.zeros((N - 1, N)).at[:].set(x - 1)
  col = jnp.arange(N)
  row = jnp.arange(1, N)[:, None]

  # note: in JAX, overflowing indices are dr…
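
The snippet above is cut off in this copy, so for reference here is a separate, minimal sketch of the same "compare each entry to every previous entry" idea. It is not the code from the answer: the function name `mark_duplicates` is made up for illustration, and it builds a full N×N boolean comparison instead of the shifted matrix `M`, so it assumes N is small enough for O[N²] memory.

```python
import jax.numpy as jnp

def mark_duplicates(x, val):
  # Hypothetical helper (name is an assumption, not from the answer above).
  # Replaces every entry that repeats an earlier entry of 1D `x` with `val`.
  x = jnp.asarray(x)
  assert x.ndim == 1
  n = x.shape[0]

  # eq[i, j] is True when x[i] == x[j]
  eq = x[:, None] == x[None, :]

  # earlier[i, j] is True when j < i, i.e. x[j] comes before x[i]
  idx = jnp.arange(n)
  earlier = idx[None, :] < idx[:, None]

  # an entry is a duplicate if it equals any strictly earlier entry
  is_dup = (eq & earlier).any(axis=1)
  return jnp.where(is_dup, val, x)

result = mark_duplicates(jnp.array([1, 2, 1, 3, 2, 1]), -1)
print(result)
```

The first occurrence of each value is kept and later repeats are replaced. Since the shapes depend only on `x.shape`, this should also work under `jax.jit`.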

Answer selected by ttt733