-
In this gist, I compare a JAX-based top-k against a NumPy-based one. Note that my implementation of top-k relies on np.argpartition, which uses the introselect algorithm. Here are the timing results on my local machine:
Notice that the NumPy implementation of top-k (based on argpartition with kth= …). So given that np.argpartition is written in C++ and uses a fairly efficient two-pass algorithm in the best case, what makes JAX so much faster than NumPy?

EDIT: Based on Jake's reply, I've added a publicly reproducible notebook version of the gist above: https://www.kaggle.com/xhlulu/numpy-argpartition-vs-jax-lax-top-k When running on a …

So it is interesting to see that the gap is not as "clear cut" as above, but there is still a significant difference (38.5 vs 48.7).
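For reference, the NumPy-based approach described above can be sketched roughly as follows (`numpy_topk` is my own naming for this sketch, not the exact code from the gist; the JAX counterpart is `jax.lax.top_k(x, k)`):

```python
import numpy as np

def numpy_topk(x, k):
    # np.argpartition uses introselect: average O(n) to move the indices
    # of the k largest values into the last k slots (unordered).
    idx = np.argpartition(x, -k)[-k:]
    # Order those k hits by value, descending: O(k log k).
    idx = idx[np.argsort(-x[idx])]
    return x[idx], idx

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000).astype(np.float32)
vals, idx = numpy_topk(x, 5)
# vals matches the top 5 of a full descending sort
assert np.array_equal(vals, np.sort(x)[::-1][:5])
```

So the total cost is O(n + k log k), versus O(n log n) for a full sort.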
Replies: 2 comments 5 replies
-
Side note: I've looked at the JAX source code, but all I was able to find is that jax.lax binds a …
-
Can you say more about the environment where you're running this? For example, if JAX is using a GPU and NumPy is using a CPU, that could easily account for the kind of performance difference you're seeing. Also, for JAX benchmarks, keep in mind the tips in the JAX FAQ entry "Benchmarking JAX code". In particular, your benchmarks don't account for JIT compilation or asynchronous dispatch, so on a backend that supports them you may just be measuring compile and dispatch time rather than actual runtime.
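As a sketch of the benchmarking pattern those tips suggest (the NumPy side is timed for real here; the JAX side is shown only in comments, since the key points are warming up the JIT outside the timer and calling `.block_until_ready()` to wait past asynchronous dispatch):

```python
import timeit
import numpy as np

def topk_np(x, k):
    # NumPy top-k via argpartition, as in the gist above (approximately).
    idx = np.argpartition(x, -k)[-k:]
    return np.sort(x[idx])[::-1]

x = np.random.default_rng(0).standard_normal(1_000_000).astype(np.float32)

# For a JAX function you would instead do roughly:
#   f = jax.jit(lambda x: jax.lax.top_k(x, 5))
#   f(x)[0].block_until_ready()   # warm-up: compile outside the timed region
#   timeit.timeit(lambda: f(x)[0].block_until_ready(), number=100)
# NumPy needs no warm-up, so we can time it directly:
t = timeit.timeit(lambda: topk_np(x, 5), number=20) / 20
print(f"numpy topk: {t * 1e3:.2f} ms per call")
```

Without the warm-up call, the first invocation's compile time gets folded into the measurement; without `block_until_ready()`, the timer can stop before the computation has actually finished.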
If it's CPU you're curious about, you can find the implementation here: https://github.com/openxla/xla/blob/f868730d8fc557f9e26c983a015f6b63d5b241b4/xla/service/cpu/runtime_topk.cc#L27-L69
It looks like it's implemented via C++ `std::partial_sort`.
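For context, `std::partial_sort(first, middle, last)` rearranges a range so that the smallest `middle - first` elements sit sorted at the front, leaving the rest in unspecified order (for top-k largest, a descending comparator flips this). A rough NumPy analogue of those semantics, not the actual XLA code:

```python
import numpy as np

def partial_sort(xs, m):
    # Mimics C++ std::partial_sort semantics: the m smallest elements
    # end up sorted at the front; the tail's order is unspecified.
    xs = np.asarray(xs)
    part = np.partition(xs, m - 1)  # m smallest moved to the front, unordered
    return np.concatenate([np.sort(part[:m]), part[m:]])

print(partial_sort([5, 2, 8, 1, 9, 3], 3)[:3])  # the 3 smallest, in order: [1 2 3]
```

Like argpartition, this only fully orders the k elements you care about, which is why both approaches beat a full sort.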