-
I am running a benchmark between a Python / NumPy implementation and a JAX implementation of simple linear regression. The implementation iterates the weight updates with gradient descent.
The JAX implementation is pretty much the same; the iteration is just wrapped into a closure function.
It is staggering that the JAX implementation runs much slower than the NumPy one. For example, with a training size of 10k and a regression dimension of 100, the most intensive operation is a matrix multiplication of (10k, 100) by (100, 1).
The JAX runtime (13s on average) is about 4 times slower than the NumPy runtime (3s on average) on CPU. The JAX version used is 0.3.14, and the benchmark Colab notebook can be found below: https://colab.research.google.com/drive/17FY2z3Og7a36Ub4pal2vSdQ7FIzP9I34?usp=sharing
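For context, a minimal NumPy sketch of the kind of loop being benchmarked (the sizes follow the description above; the variable names are illustrative, not taken from the notebook):

import numpy as np

# Illustrative sizes from the description: 10k samples, 100 features
N, d = 10_000, 100
rng = np.random.default_rng(0)
x_train = rng.normal(size=(N, d))
y_train = rng.normal(size=(N, 1))

W = np.zeros((d, 1))
b = 0.0
learning_rate = 1e-3
num_epochs = 1000

for _ in range(num_epochs):
    # Forward pass: (N, d) @ (d, 1) -> (N, 1)
    error = y_train - (x_train @ W + b)
    # Gradients of the mean squared error
    dW = -(2 / N) * (x_train.T @ error)
    db = -(2 / N) * np.sum(error)
    # Gradient-descent update
    W -= learning_rate * dW
    b -= learning_rate * db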
-
Alternatively, I reimplemented the JAX one with
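The reimplemented code is not shown here; going by the later replies, it used jax.lax.fori_loop (or lax.scan), so a hypothetical sketch, reusing the arrays and hyperparameters from the NumPy sketch above, might look like this:

import jax
import jax.numpy as jnp
from jax import lax

# Assumes x_train, y_train, N, d, learning_rate, num_epochs from the NumPy sketch above
x_train_j = jnp.asarray(x_train)
y_train_j = jnp.asarray(y_train)

@jax.jit
def train(W, b):
    def body(_, val):
        W, b = val
        error = y_train_j - (x_train_j @ W + b)
        dW = -(2 / N) * (x_train_j.T @ error)
        db = -(2 / N) * jnp.sum(error)
        return (W - learning_rate * dW, b - learning_rate * db)
    # All epochs run inside a single compiled XLA loop
    return lax.fori_loop(0, num_epochs, body, (W, b))

W_j, b_j = train(jnp.zeros((d, 1)), jnp.zeros(()))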
-
Have you read through the FAQ entry Is JAX faster than NumPy? The answer is not always yes, and there is some information there about JAX best practices, when you can expect JAX to be faster, and tips on getting accurate benchmarks.
-
Thanks for your prompt response. Yes, I am also aware of asynchronous dispatch in JAX. As the documentation recommends, my benchmark measures only the JAX runtime and does not include the transfer time or compilation time. Also, to amortize the JAX overhead on CPU, I increased the size of the matrices, but I struggle to understand why the runtime difference grows linearly with the size of the matrices.
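For what it's worth, the usual pattern for timing a jitted JAX function so that neither compilation nor asynchronous dispatch is included looks roughly like this (a sketch reusing the hypothetical train function above, not the notebook's code):

import time

# Warm-up call: triggers tracing and XLA compilation (excluded from the timing)
W_out, b_out = train(jnp.zeros((d, 1)), jnp.zeros(()))
W_out.block_until_ready()

start = time.perf_counter()
W_out, b_out = train(jnp.zeros((d, 1)), jnp.zeros(()))
W_out.block_until_ready()  # wait for the asynchronously dispatched work to finish
print(f"JAX runtime: {time.perf_counter() - start:.3f} s")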
-
Thanks so much for your detailed explanation! It makes total sense to me: I may actually be comparing different numerical libraries in the background. For example, the NumPy installed on the benchmark machine uses OpenBLAS, while TensorFlow uses Intel MKL (or Eigen?). Do you know how to find out which numerical library is used by JAX? Is there any documentation or source code I can refer to?
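A quick way to check the NumPy side, and to see which XLA backend JAX is running on, is the following sketch:

import numpy as np
import jax

np.show_config()              # prints the BLAS/LAPACK libraries NumPy was built against
print(jax.default_backend())  # "cpu", "gpu", or "tpu": the XLA backend in use
print(jax.devices())          # devices XLA dispatches to

As far as I understand, jaxlib does not link a BLAS of its own on CPU; the matmul is compiled by XLA (which historically uses Eigen kernels there), so the comparison is effectively XLA vs. OpenBLAS rather than BLAS vs. BLAS.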
-
I wonder if this is related to this issue. I find that just using a Python for-loop over the jitted _body can improve the performance (2.18s):

import jax.numpy as jnp
from jax import jit

@jit
def _body(_, val):
    W, b = val
    # Forward pass: [N x 1] · [1 x 1] = [N x 1]
    error = y_train - (x_train @ W + b)
    # Loss (not used in the update below)
    loss = (error.T @ error) / N
    # Backpropagation: gradients of the loss w.r.t. W and b
    dW = -(2 / N) * (x_train.T @ error)
    db = -(2 / N) * jnp.sum(error)
    # Update weights
    W += -learning_rate * dW
    b += -learning_rate * db
    return (W, b)

for _ in range(num_epochs):
    W, b = _body(_, (W, b))
-
Thanks @anh-tong, and I confirm the performance is much better without jax fori_loop / scan. It seems that is the root cause.
The implementation suggested by @anh-tong does not require them.
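For completeness, the scan-based variant being compared against would look roughly like this (hypothetical sketch, reusing x_train_j, y_train_j and the hyperparameters from the sketches above):

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def train_scan(W, b):
    def step(carry, _):
        W, b = carry
        error = y_train_j - (x_train_j @ W + b)
        dW = -(2 / N) * (x_train_j.T @ error)
        db = -(2 / N) * jnp.sum(error)
        return (W - learning_rate * dW, b - learning_rate * db), None
    # One compiled XLA loop over all epochs
    (W, b), _ = lax.scan(step, (W, b), None, length=num_epochs)
    return W, b

W_s, b_s = train_scan(jnp.zeros((d, 1)), jnp.zeros(()))

With a plain Python loop, each epoch is a separate call to a small compiled function, whereas fori_loop / scan compiles the whole loop into one XLA program; which is faster on CPU apparently depends on the workload, as this thread shows.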