-
I implemented an RBF kernel for pairwise comparison in JAX. When I compared its speed against scikit-learn's rbf_kernel, my JAX version was almost 10x slower. Could anyone shed some light on why my JAX implementation has a performance issue here? You can check out my Colab notebook for the same code as below:
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
from functools import partial
import sklearn.metrics.pairwise as pairwise
def pairwise_compute(kernel, xs):
    # Nested vmap: the inner vmap maps over y, the outer over x,
    # producing the full (n, n) matrix of kernel(x_i, x_j) values.
    return vmap(vmap(kernel, in_axes=(None, 0)), in_axes=(0, None))(xs, xs)
def rbf_kernel_jax(x, y, gamma):
    return jnp.exp(-gamma * jnp.linalg.norm(x - y) ** 2)
X = np.random.randn(1000, 500)
jax_func = jit(partial(pairwise_compute, partial(rbf_kernel_jax, gamma=0.5)))
sklearn_func = partial(pairwise.rbf_kernel, gamma=0.5)
# jit compile
_ = jax_func(X)
# profile JAX
%timeit -n 10 jax_func(X).block_until_ready()
# 10 loops, best of 5: 671 ms per loop
# profile sklearn
%timeit -n 10 sklearn_func(X)
# 10 loops, best of 5: 71.7 ms per loop
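For reference, the two functions do agree numerically, so the gap is purely speed. A quick sanity check (a sketch using the definitions above; JAX defaults to float32 while sklearn computes in float64, hence the loose tolerance):
K_jax = np.asarray(jax_func(X))
K_sklearn = sklearn_func(X)
print(np.allclose(K_jax, K_sklearn, atol=1e-5))  # expect True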
-
I think the reason sklearn's version is faster is that it uses a factored computation for pairwise euclidean distances; see https://github.com/scikit-learn/scikit-learn/blob/aebd9f6295da57a4b3f03d00a3527f8c4b5f811f/sklearn/metrics/pairwise.py#L369-L373
Expanding ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y> turns the cross terms into a single matrix product, which uses many fewer FLOPs to compute the pairwise euclidean norms, at the expense of accuracy. The JAX algorithm you created is doing a much more FLOP-intensive calculation that does not have this accuracy problem. I'll note that even with this more expensive algorithm, on GPU your JAX code is several times faster than the sklearn approach.
I think this is roughly equivalent to the algorithm sklearn is using:
@jit
def rbf_kernel(X, gamma):
    # Squared norm of each row, and all pairwise inner products via one matmul.
    XX = (X ** 2).sum(1)
    XY = X @ X.T
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, broadcast over all pairs.
    sq_distances = XX[:, None] + XX - 2 * XY
    return jnp.exp(-gamma * sq_distances)
When I run your benchmarks, I find the JAX version is about 2x faster than sklearn on a Colab CPU. The downside here (for both JAX and sklearn) is that in cases where both …
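As a quick check that the factored version agrees with sklearn, here is a sketch (my assumption: the jnp.maximum clamp guards against the small negative squared distances that floating-point cancellation can produce; it is not quoted from sklearn's code):
import numpy as np
import jax.numpy as jnp
from jax import jit
import sklearn.metrics.pairwise as pairwise

@jit
def rbf_kernel_clamped(X, gamma):
    # Same factored computation as above, with any negative values from
    # cancellation clamped to zero before the exp.
    XX = (X ** 2).sum(1)
    XY = X @ X.T
    sq_distances = jnp.maximum(XX[:, None] + XX - 2 * XY, 0.0)
    return jnp.exp(-gamma * sq_distances)

X = np.random.randn(1000, 500)
K_jax = np.asarray(rbf_kernel_clamped(X, 0.5))
K_sklearn = pairwise.rbf_kernel(X, gamma=0.5)
print(np.abs(K_jax - K_sklearn).max())  # should be small, at float32 precision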