
No, there's no special function for this, though it has been discussed in #9235.

In case someone comes across this in a search, here's how you might implement simple pairwise Euclidean distances in JAX using broadcasted array operations. This is probably the most efficient way to compute them robustly in JAX, and probably how any future pairwise API would be implemented:

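```python
import jax.numpy as jnp

def pairwise_euclidean(x, y):
  # x: (n, d), y: (m, d) -> distances of shape (n, m)
  assert x.ndim == y.ndim == 2
  # Broadcast (n, 1, d) - (1, m, d) to (n, m, d), then reduce over the feature axis
  return jnp.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))
```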
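If you find explicit broadcasting hard to read, the same computation can be written with `jax.vmap` over a single-pair distance function. This is a sketch that does not come from the thread; `euclidean` and `pairwise_euclidean_vmap` are illustrative names, and the two versions should agree numerically on the small example below.

```python
import jax
import jax.numpy as jnp

def pairwise_euclidean(x, y):
  # Broadcasted version: (n, 1, d) - (1, m, d) -> (n, m, d), summed to (n, m)
  assert x.ndim == y.ndim == 2
  return jnp.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))

def euclidean(a, b):
  # Distance between two single vectors of shape (d,)
  return jnp.sqrt(((a - b) ** 2).sum())

# Inner vmap maps over rows of y, outer vmap maps over rows of x -> shape (n, m)
pairwise_euclidean_vmap = jax.vmap(
    jax.vmap(euclidean, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.array([[0.0, 0.0], [3.0, 4.0]])
y = jnp.array([[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]])
d1 = pairwise_euclidean(x, y)       # [[0, 10, 3], [5, 5, 4]]
d2 = pairwise_euclidean_vmap(x, y)  # same values via vmap
```

Both return an `(n, m)` array; which form to prefer is largely a matter of readability.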

(I say "robustly" because there are ways to factor out terms and compute the same result in fewer operations, but those approaches are numerically unstable in edge cases.)
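A common "fewer operations" trick expands the squared distance as ||x||² + ||y||² − 2 x·y so that the cross term becomes a single matrix multiply. We are assuming this is the kind of factoring meant above. The sketch below (`pairwise_euclidean_expanded` is a hypothetical name) compares it against the broadcasted form for two nearby points with a large norm in float32, JAX's default dtype.

```python
import jax.numpy as jnp

def pairwise_euclidean(x, y):
  # Broadcasted form: subtract first, then square and sum
  assert x.ndim == y.ndim == 2
  return jnp.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))

def pairwise_euclidean_expanded(x, y):
  # Hypothetical expanded form: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y
  assert x.ndim == y.ndim == 2
  x_sq = (x ** 2).sum(-1)[:, None]     # shape (n, 1)
  y_sq = (y ** 2).sum(-1)[None, :]     # shape (1, m)
  sq = x_sq + y_sq - 2.0 * (x @ y.T)   # shape (n, m)
  # Rounding can push sq below zero, so clamp before sqrt to avoid NaN
  return jnp.sqrt(jnp.maximum(sq, 0.0))

# Two nearby 1-D points with a large norm (float32 by default)
x = jnp.array([[10000.0]])
y = jnp.array([[10000.01]])

direct = pairwise_euclidean(x, y)             # about 0.0098
expanded = pairwise_euclidean_expanded(x, y)  # far off in float32
```

In the direct form, subtracting two nearby float32 values is exact, so the small distance survives. In the expanded form, the true squared distance (about 1e-4) is the difference of terms near 2e8, where float32 values are spaced 16 apart, so it is swamped by rounding error.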

Answer selected by yuhanfu