
I think sklearn's version is faster because it uses a factored computation for Euclidean distances; see https://github.com/scikit-learn/scikit-learn/blob/aebd9f6295da57a4b3f03d00a3527f8c4b5f811f/sklearn/metrics/pairwise.py#L369-L373
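The factoring in question is the expansion of the squared norm. For rows $x_i$ of $X$ and $y_j$ of $Y$, with $a_i = \|x_i\|^2$ and $b_j = \|y_j\|^2$:

```math
\|x_i - y_j\|^2 = \|x_i\|^2 - 2\,x_i^\top y_j + \|y_j\|^2
\quad\Longrightarrow\quad
D = a\,\mathbf{1}^\top - 2\,X Y^\top + \mathbf{1}\,b^\top,
\qquad
K_{ij} = \exp(-\gamma\, D_{ij})
```

The cross term $XY^\top$ is a single matrix multiply. When $x_i$ and $y_j$ are close to each other but far from the origin, the subtraction cancels most of the significant digits, which is where the accuracy loss comes from.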

This needs fewer FLOPs to compute pairwise Euclidean distances, because the bulk of the work becomes one matrix multiply, but it sacrifices accuracy. The JAX algorithm you wrote does a much more FLOP-intensive calculation that does not have this accuracy problem.
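To make the accuracy trade-off concrete, here is a small sketch comparing the two approaches in float32. The function names `sq_dists_direct` and `sq_dists_factored` are hypothetical, and the direct version is only an assumption about what a "non-factored" computation looks like; it is not necessarily the code from the question.

```python
import jax.numpy as jnp

def sq_dists_direct(X, Y):
    # Form every difference x_i - y_j explicitly: an (n, m, d) intermediate.
    # Subtracting nearby points first keeps the small differences exact.
    diff = X[:, None, :] - Y[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)

def sq_dists_factored(X, Y):
    # ||x||^2 - 2 x.y + ||y||^2: the cross term is a single (n, d) @ (d, m)
    # matmul, but the final sum subtracts large, nearly equal numbers.
    XX = jnp.sum(X ** 2, axis=1)[:, None]
    YY = jnp.sum(Y ** 2, axis=1)[None, :]
    return XX - 2.0 * (X @ Y.T) + YY

# Two points far from the origin but very close to each other (float32).
X = jnp.array([[1e4, 1e4]], dtype=jnp.float32)
Y = X + 1e-2

print(sq_dists_direct(X, Y))    # ~1.9e-4, the true squared distance
print(sq_dists_factored(X, Y))  # dominated by rounding error
```

The direct form materializes an n×m×d array, which is what makes it expensive; the factored form avoids that but inherits cancellation error.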

I'll note that even with this more expensive algorithm, on GPU your JAX code is several times faster than the sklearn approach.

I think this is roughly equivalent to the algorithm sklearn is using:

```python
@jit
def rbf_kernel(X, gamma)…
```
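The snippet above is truncated, so what follows is not that code. It is a minimal sketch, under my own assumptions, of what a factored RBF kernel might look like in JAX. The name `rbf_kernel_factored` is hypothetical, and the clipping of small negative values is an assumed safeguard against rounding, not something taken from the original snippet.

```python
import jax
import jax.numpy as jnp

@jax.jit
def rbf_kernel_factored(X, gamma):
    # Squared row norms, shape (n,).
    sq_norms = jnp.sum(X ** 2, axis=1)
    # Factored squared distances: ||x_i||^2 - 2 x_i.x_j + ||x_j||^2.
    # The cross term is one matrix multiply, which maps well onto BLAS/GPU.
    sq_dists = sq_norms[:, None] - 2.0 * (X @ X.T) + sq_norms[None, :]
    # Rounding can leave tiny negative values; clip them to zero (assumption).
    sq_dists = jnp.maximum(sq_dists, 0.0)
    return jnp.exp(-gamma * sq_dists)

X = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
K = rbf_kernel_factored(X, 0.5)  # (5, 5) kernel matrix
```

`gamma` is passed as a traced argument rather than a static one, so calling the jitted function with a different `gamma` value does not trigger recompilation.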

Answer selected by riven314