Can I use scipy.spatial.distance in jax? #15862
Hi - thanks for the question! Unfortunately there's no JAX wrapper for that. Fortunately, though, pairwise distance operations are pretty straightforward to implement efficiently in JAX. A big reason that i[...]

import jax
import jax.numpy as jnp
from scipy.spatial import distance
import numpy as np
@jax.jit
def cdist(x, y):
    # Broadcast x to (n, 1, d) and y to (1, m, d), then reduce over the last axis.
    return jnp.sqrt(jnp.sum((x[:, None] - y[None, :]) ** 2, -1))
rng = np.random.RandomState(0)
x = rng.rand(2, 3)
y = rng.rand(4, 3)
print(distance.cdist(x, y))
# [[0.41689499 0.19662693 0.57216693 0.86543108]
# [0.57586803 0.41860234 0.76350759 0.63809564]]
print(cdist(x, y))
# [[0.416895 0.19662698 0.572167 0.8654311 ]
# [0.575868 0.41860238 0.76350766 0.6380957 ]]
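
For the `pdist` case specifically, the same idea should carry over. Below is a minimal sketch, not an official JAX API: the `pdist` function here is just an illustrative name, and it assumes plain Euclidean distance and a 2D input of shape `(n, d)`. It picks out the `i < j` pairs so the result has the same condensed layout that `scipy.spatial.distance.pdist` returns.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import distance


@jax.jit
def pdist(x):
    # Illustrative helper, not part of jax.scipy.
    # x.shape is static under jit, so the index pairs (i, j) with i < j can be
    # built with plain NumPy at trace time. They come out in row-major order,
    # which is the order SciPy uses for its condensed distance vector.
    i, j = np.triu_indices(x.shape[0], k=1)
    return jnp.sqrt(jnp.sum((x[i] - x[j]) ** 2, axis=-1))


rng = np.random.RandomState(0)
x = rng.rand(5, 3)

expected = distance.pdist(x)  # float64 reference from SciPy
result = pdist(x)             # float32 by default in JAX

print(np.allclose(result, expected, atol=1e-5))
```

One thing to watch if you differentiate through this: the gradient of `sqrt` is not finite at zero, so duplicate rows in `x` (zero distance) will likely produce NaN gradients unless you guard the square root.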
I know that jax.scipy provides JAX implementations of some scipy functions, but scipy.spatial does not seem to be implemented there. I want to use scipy.spatial.distance.pdist on traced JAX values (jax tracers). Do I need to implement it myself using jnp?