Hi, I'm new to JAX and wonder if there is an equivalent function for … Thanks!
No, there's no special function for this (though it's been discussed in #9235).

In case someone comes across this in a search, here's how you might implement simple pairwise Euclidean distances in JAX using broadcasted vector operations. This is probably the most efficient way to compute this robustly in JAX, and probably how any future pairwise API would be implemented:

```python
import jax.numpy as jnp

def pairwise_euclidean(x, y):
  # x: (n, d), y: (m, d); broadcast to (n, m, d), sum squares over d -> (n, m)
  assert x.ndim == y.ndim == 2
  return jnp.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))
```

(I say "robustly" because there are ways to factor out terms and compute the same result in fewer operations, but those approaches are numerically unstable in edge cases.)
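To make the "robustly" caveat concrete, here is a minimal sketch of one such factored form. It assumes the answer is referring to the common expansion ||x - y||² = ||x||² - 2 x·y + ||y||², which the answer doesn't name explicitly, and `pairwise_euclidean_factored` is a hypothetical name. The factored version swaps the (n, m, d) intermediate for a matrix multiply. However, it subtracts large, nearly equal squared norms, which loses precision in float32 when points are close together but far from the origin.

```python
import jax
import jax.numpy as jnp

def pairwise_euclidean(x, y):
  # Difference-first version: subtract, then square, so small gaps survive.
  assert x.ndim == y.ndim == 2
  return jnp.sqrt(((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))

def pairwise_euclidean_factored(x, y):
  # Hypothetical factored variant (assumption, not from the thread):
  # ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, computed with a matmul.
  assert x.ndim == y.ndim == 2
  x_sq = jnp.sum(x ** 2, axis=-1)[:, None]  # (n, 1)
  y_sq = jnp.sum(y ** 2, axis=-1)[None, :]  # (1, m)
  sq = x_sq - 2.0 * (x @ y.T) + y_sq        # (n, m)
  # Cancellation can make sq slightly negative; clamp so sqrt doesn't give NaN.
  return jnp.sqrt(jnp.maximum(sq, 0.0))

# Two points 0.5 apart, but about 1e4 away from the origin (float32 by default).
x = jnp.array([[1e4, 1e4]])
y = jnp.array([[1e4 + 0.5, 1e4]])
print(pairwise_euclidean(x, y))           # [[0.5]]
print(pairwise_euclidean_factored(x, y))  # can be far from 0.5 in float32
```

On well-scaled data the two functions agree closely. On this input, though, the difference-first version returns exactly 0.5, while the factored version can be badly off because squared norms near 2e8 are only representable to a spacing of 16 in float32. Clamping with `jnp.maximum` only prevents NaNs from slightly negative squared values. It doesn't recover the lost precision.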