Overhead of host_callback vs jax primitives? #7022
I'm in a situation where I need to repeatedly call external code inside of a loop. I need to take in a […]

As a first pass, I just copied from […]:

    import jax
    import numpy as np

    def crude_version(x):
        # Copy the device array into a writable host NumPy array.
        x_np = np.asarray(x).copy()
        return jax.device_put(extern_func(x_np))

This has obvious drawbacks: I can't jit any function that contains it.

I tried a second version using host_callback:

    from jax.experimental import host_callback

    def hcb_version(x):
        return host_callback.call(extern_func, x, result_shape=out_shape)

This works and lets me jit/vmap/pmap, but there is a drawback: it is over 10x slower than the crude version.

Two questions: […]
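For anyone wanting to reproduce this kind of comparison, here is a minimal timing sketch. It is not the original benchmark: `extern_func` is a hypothetical NumPy stand-in for the external code, and `time_per_call` is a helper name made up for illustration. Each call is blocked on so that asynchronous dispatch does not hide the cost.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def extern_func(x_np):
    # Hypothetical stand-in for the external code: NumPy in, NumPy out.
    return np.sin(x_np) * 2.0


def crude_version(x):
    # Device -> host copy, external call, host -> device copy (not jit-compatible).
    x_np = np.asarray(x).copy()
    return jax.device_put(extern_func(x_np))


def time_per_call(fn, x, n_calls=100):
    # Warm up once so one-time setup cost is excluded from the measurement.
    fn(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_calls):
        # Block on every result so each iteration measures a full round trip.
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / n_calls


x = jnp.arange(1024, dtype=jnp.float32)
seconds = time_per_call(crude_version, x)
print(f"crude_version: {seconds * 1e3:.3f} ms per call")
```

The same helper can be pointed at a jitted callback-based version to compare per-call overhead on the same input.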
Replies: 2 comments 7 replies
Thanks for the questions!

For your first question, the host_callback mechanism is indeed inefficient and is being revised. So I guess I'd say yes, it's expected, but temporary. Ultimately its overheads (or the overheads of the API that replaces it) should be made low. But that doesn't help you now!

For the second question, writing a Primitive isn't necessarily an alternative; IIUC you'd still need a way for, say, a jit-compiled GPU program to call back onto the host to run your extern_func. So your Primitive's translation rule would need to solve the same problem that host_callback's machinery needs to solve. That's what I mean by it's not really an alternative: we still have the question …

Do you have bindings to […]

On CPU, where you wouldn't need to transfer buffers to and from the GPU, if you have Cython bindings it's not too hard to rig up the kind of CustomCall mechanism you'd want underneath a Primitive for good performance. (host_callback doesn't yet use CustomCalls, but it's being revised to use them!) You can see examples in lapack.pyx.

If your external code accepted GPU arrays then you could do something like in cuda_prng_kernels.cc, cuda_prng_kernels.cu.cc, and cuda_prng.py, which use a GPU CustomCall. But if you need to transfer data to and from the GPU inside the call, I don't think we have a good example yet. The revised host_callback would be a good example, but with that you might be happy just using host_callback itself.

See also this amazing "Extending JAX with custom C++ and CUDA code" tutorial.

So to summarize: […]

Hi Matt and co,

I'm finding that host_callback still takes about 20ms per call. Is there any hope that host_callback will get faster soon? Or should I not count on it?

Thanks,
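For readers on newer JAX releases: host_callback was later deprecated in favor of callback APIs such as `jax.pure_callback`. Below is a minimal sketch of the jit-compatible pattern discussed in this thread, assuming a release where `jax.pure_callback` is available. `extern_func` is again a hypothetical NumPy stand-in, and this says nothing about whether the per-call overhead is lower in any given version.

```python
import jax
import jax.numpy as jnp
import numpy as np


def extern_func(x):
    # Hypothetical stand-in for the external host-side code.
    # np.asarray covers JAX versions that pass either NumPy or JAX arrays here.
    return (np.asarray(x) ** 2).astype(np.float32)


@jax.jit
def callback_version(x):
    # The result's shape and dtype must be declared up front so jit can trace the call.
    out_spec = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    return jax.pure_callback(extern_func, out_spec, x)


x = jnp.arange(4, dtype=jnp.float32)
y = callback_version(x)
print(y)
```

`jax.pure_callback` is meant for functions without side effects, so this sketch only fits if the external code behaves like a pure function of its input.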