
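You can do this with jnp.unique. Passing axis=0 makes it compare whole rows instead of individual elements, so it returns the distinct rows in sorted order: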

import jax.numpy as jnp
a = jnp.array([[1,0,0,1],[1,0,1,1],[0,1,1,0],[1,0,0,1],[1,0,1,1]])
print(jnp.unique(a, axis=0))
# [[0 1 1 0]
#  [1 0 0 1]
#  [1 0 1 1]]
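If you also need to know how often each row occurs, or which unique row each original row maps to, jnp.unique accepts the same return_counts and return_inverse flags as NumPy. Below is a minimal sketch, assuming a recent JAX release. The helper name unique_rows_padded is hypothetical. Note that jnp.unique is not traceable under jax.jit unless you pass a static size, because the number of unique rows is data-dependent. The fill_value argument pads the unused slots.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 0, 0, 1],
               [1, 0, 1, 1],
               [0, 1, 1, 0],
               [1, 0, 0, 1],
               [1, 0, 1, 1]])

# return_counts: how many times each unique row occurs.
# return_inverse: for every original row, the index of its unique row.
rows, inverse, counts = jnp.unique(a, axis=0, return_inverse=True, return_counts=True)
print(rows)
print(counts)           # [1 2 2]
print(inverse.ravel())  # [1 2 0 1 2]

# Under jax.jit the output shape must be static, so a fixed `size` is required.
# Slots beyond the true number of unique rows (3 here) are set to `fill_value`.
@jax.jit
def unique_rows_padded(x):  # hypothetical helper name
    return jnp.unique(x, axis=0, size=4, fill_value=-1)

padded = unique_rows_padded(a)
print(padded)  # the three unique rows, then a fourth row of -1s
```

If you leave out fill_value when size is given, JAX pads by repeating an existing unique row. That is easy to mistake for real data, so an explicit sentinel such as -1 is usually safer.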

Answer selected by yiminghwang