-
If you're interested in microbenchmarks of JAX operations, please take a look at https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code. In particular, I think JAX's asynchronous dispatch may be muddying your results here. But that aside, it wouldn't surprise me if your JAX implementation is much slower than the Python implementation, because in the Python version your deque is storing pointers to an already allocated array, whereas in your JAX implementation in each step you are copying the contents of one array element-by-element into another.
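The asynchronous-dispatch point can be made concrete: JAX returns control to Python before the computation finishes, so a naive timer measures dispatch, not work. A minimal benchmarking sketch (the function `f` and matrix size here are illustrative, not from the thread):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Illustrative workload: a matmul followed by a reduction.
    return (x @ x).sum()

x = jnp.ones((100, 100))

# Warm-up call: the first invocation includes compilation time.
f(x).block_until_ready()

start = time.perf_counter()
# block_until_ready() waits for the asynchronously dispatched work
# to actually finish, so the timer covers the real computation.
result = f(x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"{elapsed:.6f} s")
```

Without the `block_until_ready()` calls, the measured time can be orders of magnitude smaller than the true runtime.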
-
I am very new to JAX, so please bear with me. I am currently working on a circular buffer implementation in JAX that I can safely pass to jittable functions. For now, the implementation looks as follows:
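(The original code block was lost in extraction. For orientation, a minimal sketch of what a jittable circular buffer along these lines might look like; the names `make_buffer` and `push`, and the `(data, index)` state layout, are my assumptions, not the poster's actual code:)

```python
import jax
import jax.numpy as jnp

def make_buffer(capacity, item_shape):
    # State is a pytree of arrays, so it can be passed through jit.
    return jnp.zeros((capacity,) + item_shape), jnp.array(0)

@jax.jit
def push(state, item):
    data, idx = state
    capacity = data.shape[0]  # static: part of the traced shape
    # Functional update: .at[...].set(...) returns a new array.
    data = data.at[idx % capacity].set(item)
    return data, idx + 1
```

A typical use would be `state = make_buffer(4, (3,))` followed by repeated `state = push(state, item)` calls.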
Next, I compared the speed of my JAX buffer to `collections.deque`. From running this script, I obtain the following output.
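(The script and its output were lost in extraction. A hedged sketch of what such a comparison might look like; the names `fill_deque_mem` and `fill_jax_buffer` are taken from the post, but their bodies and the constants here are my own guesses:)

```python
import time
from collections import deque
import jax.numpy as jnp

N = 1000
CAPACITY = 100

def fill_deque_mem(items):
    # deque stores references to already-allocated arrays: O(1) per append.
    buf = deque(maxlen=CAPACITY)
    for item in items:
        buf.append(item)
    return buf

def fill_jax_buffer(items):
    # Outside jit, each .at[...].set(...) materialises a fresh copy of buf.
    buf = jnp.zeros((CAPACITY,) + items[0].shape)
    for i, item in enumerate(items):
        buf = buf.at[i % CAPACITY].set(item)
    return buf

items = [jnp.full((3,), float(i)) for i in range(N)]

t0 = time.perf_counter()
fill_deque_mem(items)
t_deque = time.perf_counter() - t0

t0 = time.perf_counter()
fill_jax_buffer(items).block_until_ready()
t_jax = time.perf_counter() - t0
print(f"deque: {t_deque:.4f}s  jax: {t_jax:.4f}s")
```

The copy-per-update behaviour of the un-jitted JAX version is exactly what the reply above attributes the slowdown to.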
The difference in speed between `fill_jax_buffer()` and `jitted_fill_jax_buffer()` seems reasonable. However, I am unsure about the low runtime of `fill_deque_mem()`, especially since `deque` objects aren't restricted to holding only `chex.Array` objects. What ways are there for me to accelerate my buffer implementation?

I conducted a similar test as above, but first generated a random `ndarray` called `elements` of shape `(n_elements, 3)`. In iteration `i`, I then appended `elements.at[i,:].get()` to the buffers. In this test, `jitted_fill_jax_buffer()` outperformed `fill_deque_mem()`. However, I believe that this could also be due to the fact that the outer `@jit` is able to further optimise the `elements.at[i,:].get()` calls. What do you think?