Why is jax.vmap over jax.numpy.convolve so slow on CPU? #11694
xinyuxiao113 asked this question in Q&A
I have noticed a strange phenomenon: on CPU, the running time of `jax.vmap` applied to `jax.numpy.convolve` grows much faster than linearly with the batch size.
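
The scaling can be checked with a small timing script. This is a sketch, not the original benchmark: the signal and kernel lengths, the batch sizes, and the helper name `time_vmap` are hypothetical choices. The batched function is wrapped in `jax.jit`, compiled once before timing, and every call is followed by `block_until_ready()` so that compilation and JAX's asynchronous dispatch stay out of the measurement.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical sizes chosen for illustration; the original post does not give them.
SIGNAL_LEN = 1024
KERNEL_LEN = 64

# vmap maps jnp.convolve over the leading axis of both the signals and the kernels.
batched_convolve = jax.jit(jax.vmap(jnp.convolve))


def time_vmap(batch: int, repeats: int = 5) -> float:
    """Return the mean wall-clock time (seconds) of one batched convolve call."""
    key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (batch, SIGNAL_LEN))
    k = jax.random.normal(key_k, (batch, KERNEL_LEN))
    batched_convolve(x, k).block_until_ready()  # compile once, outside the timer
    start = time.perf_counter()
    for _ in range(repeats):
        batched_convolve(x, k).block_until_ready()
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    for batch in (1, 8, 32, 128):
        print(f"batch={batch:4d}  {time_vmap(batch) * 1e3:8.3f} ms")
```

If the time per call grew linearly, doubling the batch would roughly double the printed time; the post reports that on CPU it grows considerably faster than that.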

I can even replace `jax.vmap` with a plain for loop and get a running time that grows linearly. So what is wrong with `jax.vmap` here? By the way, this problem does not occur on GPU.
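
A sketch of the for-loop alternative next to the `vmap` version, under the same caveats: the sizes and the helper names `loop_convolve` and `mean_time` are hypothetical. The script first checks that both approaches return the same values, so the timing comparison is between two ways of computing the same result.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical sizes; the original post does not state them.
BATCH, SIGNAL_LEN, KERNEL_LEN = 64, 1024, 64

single_convolve = jax.jit(jnp.convolve)
batched_convolve = jax.jit(jax.vmap(jnp.convolve))


def loop_convolve(x: jax.Array, k: jax.Array) -> jax.Array:
    """Convolve row by row with a Python loop, then stack the results."""
    return jnp.stack([single_convolve(x[i], k[i]) for i in range(x.shape[0])])


def mean_time(fn, *args, repeats: int = 5) -> float:
    """Mean wall-clock seconds per call, excluding the first (compiling) call."""
    fn(*args).block_until_ready()  # warm-up / compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (BATCH, SIGNAL_LEN))
k = jax.random.normal(key_k, (BATCH, KERNEL_LEN))

# Both strategies should produce the same numbers; only the speed may differ.
np.testing.assert_allclose(loop_convolve(x, k), batched_convolve(x, k), rtol=1e-4, atol=1e-3)
print(f"vmap: {mean_time(batched_convolve, x, k) * 1e3:.3f} ms")
print(f"loop: {mean_time(loop_convolve, x, k) * 1e3:.3f} ms")
```

Rerunning with several values of `BATCH` shows how each approach scales on a given machine and backend.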