Remove duplicate elements from an array #12935
Answered by jakevdp · yiminghwang asked this question in Q&A
Hi there, is there a good way to remove duplicate elements (rows) from a DeviceArray? With a plain Python list of lists I can do:
a_list = [[1,0,0,1],[1,0,1,1],[0,1,1,0],[1,0,0,1],[1,0,1,1]]
res_list = [list(t) for t in set(tuple(row) for row in a_list)]
Is there an efficient way to do this in JAX? Should I filter directly with fori_loop?
jakevdp, Oct 22, 2022:
You can do this with jnp.unique; axis=0 treats each row as a single element, and the unique rows come back sorted:
import jax.numpy as jnp
a = jnp.array([[1,0,0,1],[1,0,1,1],[0,1,1,0],[1,0,0,1],[1,0,1,1]])
print(jnp.unique(a, axis=0))
# [[0 1 1 0]
#  [1 0 0 1]
#  [1 0 1 1]]
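Two follow-ups, as a minimal sketch assuming a recent JAX release. First, the number of unique rows depends on the data, so jnp.unique can only run inside jax.jit if you pass a static `size`; slots past the true count are padded with `fill_value` (here -1, an arbitrary choice that should not collide with real rows if you need to tell padding apart). This usually removes the need for a hand-written fori_loop filter. Second, the result is sorted; `return_index=True` gives the index of each unique row's first occurrence, which can be used to restore the original order. The helper names `dedup_rows_padded` and `dedup_rows_in_order` are hypothetical.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 0, 0, 1], [1, 0, 1, 1], [0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 1, 1]])


# Hypothetical helper: under jit the output shape must be static, so `size`
# fixes how many rows are returned; unused rows are filled with -1.
@jax.jit
def dedup_rows_padded(x):
    return jnp.unique(x, axis=0, size=x.shape[0], fill_value=-1)


# Hypothetical helper (not jitted, since the output shape is data-dependent):
# keep unique rows in order of first appearance rather than sorted order,
# by sorting the first-occurrence indices returned by jnp.unique.
def dedup_rows_in_order(x):
    _, first_idx = jnp.unique(x, axis=0, return_index=True)
    return x[jnp.sort(first_idx)]


print(dedup_rows_padded(a))
print(dedup_rows_in_order(a))
```

If you need the true number of unique rows alongside the padded result, it can be computed separately outside jit, for example as the length of the unpadded jnp.unique output.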
Answer selected by yiminghwang