x.at[].set() operation causes MemcpyD2H, which is time-consuming #19137
Sun-Xiaohui asked this question in Q&A
Hi, JAX Team. Because jax.Array is immutable, append or insert operations in NumPy have to be rewritten as x.at[].set() in JAX. However, this causes memory copies (Memcpy), which are time-consuming. Recently, I implemented the CRF viterbi_decode function in JAX, but it runs slower than PyTorch even when using @jax.jit. Below is the main part of viterbi_decode:

I got a profiling result like this using the JAX profiler:

A large number of MemcpyD2H operations are introduced. In fact, I have achieved a clear speedup in model training (NER and LLM) with @jax.jit, but testing is slower than PyTorch because of x.at[].set(). How can I bypass x.at[].set(), or is there any way to speed it up? Thanks for your time!
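For context on the pattern the question describes, here is a minimal sketch contrasting NumPy's in-place assignment with JAX's functional `x.at[].set()` update. The helper `fill_squares` is hypothetical and only for illustration. According to the JAX documentation for `jax.numpy.ndarray.at`, inside jit-compiled code the compiler can perform the update in place when the input array is not reused afterwards, so a single `.at[].set()` under `jit` is not necessarily a full copy.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: mutate a preallocated buffer in place.
buf_np = np.zeros(4, dtype=np.int32)
buf_np[2] = 7

# JAX: arrays are immutable, so .at[].set() returns an updated array.
buf = jnp.zeros(4, dtype=jnp.int32)
new_buf = buf.at[2].set(7)  # `buf` itself is left unchanged


@jax.jit
def fill_squares(out):
    """Hypothetical helper: writes i * i into out[i] for every index i."""

    def body(i, b):
        # Inside jit, XLA may do this update in place because the old `b`
        # is not used again after this step.
        return b.at[i].set(i * i)

    return jax.lax.fori_loop(0, out.shape[0], body, out)


squares = fill_squares(jnp.zeros(5, dtype=jnp.int32))
```

Whether XLA actually reuses the buffer depends on the compiled program, so profiling remains the way to confirm it.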
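One common way to avoid writing into a preallocated buffer with `x.at[].set()` at every step is to express the Viterbi recursion with `jax.lax.scan`, which carries the running scores and stacks the per-step backpointers as its output. This is a minimal sketch, not the poster's code (which is not shown here). It assumes a single unbatched sequence, `emissions` of shape `[seq_len, num_tags]`, `transitions[i, j]` as the score of moving from tag `i` to tag `j`, and no separate start/end transition scores.

```python
import jax
import jax.numpy as jnp


@jax.jit
def viterbi_decode(emissions, transitions):
    """Sketch of Viterbi decoding for one sequence (assumed shapes).

    emissions:   [seq_len, num_tags] per-position tag scores
    transitions: [num_tags, num_tags], transitions[i, j] = score of tag i -> tag j
    Returns (best_path of shape [seq_len], best_score scalar).
    """

    def forward_step(score, emission):
        # score: [num_tags], best score of any path ending in each tag so far.
        # candidates[i, j]: best score ending in tag j, coming from tag i.
        candidates = score[:, None] + transitions + emission[None, :]
        best_prev = jnp.argmax(candidates, axis=0)  # backpointers, [num_tags]
        new_score = jnp.max(candidates, axis=0)     # [num_tags]
        return new_score, best_prev

    # scan carries the scores and stacks backpointers: no .at[].set() needed.
    final_score, backpointers = jax.lax.scan(
        forward_step, emissions[0], emissions[1:]
    )  # backpointers: [seq_len - 1, num_tags]

    last_tag = jnp.argmax(final_score)

    def backward_step(tag, bp):
        # bp[tag] is the best tag at the previous position.
        prev_tag = bp[tag]
        return prev_tag, prev_tag

    # reverse=True walks backwards but stores outputs at their original index,
    # so prev_tags[k] is the best tag at position k for k < seq_len - 1.
    _, prev_tags = jax.lax.scan(backward_step, last_tag, backpointers, reverse=True)
    best_path = jnp.concatenate([prev_tags, last_tag[None]])
    return best_path, jnp.max(final_score)
```

Batches of sequences could be handled with `jax.vmap` over this function; whether this removes the MemcpyD2H operations seen in the profile would still need to be checked with the JAX profiler.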