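I think the reason is that most jnp operations are jit-compiled by default, whereas lax operations are not. You can see this by disabling jit, which puts the jnp calls on the same un-jitted path as the lax calls:

```python
# Run in an IPython/Colab cell: %timeit is an IPython magic, not plain Python.
import jax
import jax.numpy as jnp
from jax import lax

with jax.disable_jit():
    %timeit lax.sin(1.).block_until_ready()
    %timeit lax.cos(1.).block_until_ready()
    %timeit jnp.sin(1.).block_until_ready()
    %timeit jnp.cos(1.).block_until_ready()
```

(Note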
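A complementary check, sketched here as plain Python so it runs outside a notebook: if the jnp speedup really comes from jit, then explicitly wrapping the bare lax primitive in jax.jit should bring it in line with jnp. The names jit_lax_sin, candidates and results are illustrative only. Timings depend on the JAX version, backend and machine; on recent releases the gap between the eager lax call and the others may be small or absent.

```python
import timeit

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical helper: an explicitly jit-compiled wrapper around the lax primitive.
jit_lax_sin = jax.jit(lax.sin)

candidates = {
    "lax.sin": lambda: lax.sin(1.0).block_until_ready(),
    "jit(lax.sin)": lambda: jit_lax_sin(1.0).block_until_ready(),
    "jnp.sin": lambda: jnp.sin(1.0).block_until_ready(),
}

results = {}
for name, fn in candidates.items():
    fn()  # warm-up call so one-time compilation is excluded from the timing
    # Best of 3 repeats of 200 calls each, converted to microseconds per call.
    per_call = min(timeit.repeat(fn, number=200, repeat=3)) / 200
    results[name] = per_call * 1e6
    print(f"{name:>14}: {results[name]:.1f} us per call")
```

If jit is the explanation, the jit(lax.sin) and jnp.sin rows should come out close to each other, with the eager lax.sin row as the outlier.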
-
From my understanding, the operations in jax.numpy internally call their associated primitive operations from the jax.lax module, so their speed should be comparable. But I found that operations from jax.lax are generally slower than those from jax.numpy (as attached below). Why is there such a discrepancy in speed?

FYI: I am using JAX version 0.2.25 in Colab (CPU mode).