Seeking feedback from JAX developers. Is my JAX to CuPy bridge cursed or blessed? #34732
Unanswered
steppi asked this question in Show and tell
I am a maintainer of SciPy and of `xsf`, a header-only library of special-function scalar kernels for use on both CPU and GPU. I've been working on support for alternative Python array API standard backends in SciPy. I'm often asked why the `xsf` special function kernels can't be used for all backends when no other implementation is available, and usually I answer something like "well, JAX does its own thing with XLA, which allows stuff like fusing across function boundaries and automatically supporting autograd, so I don't think you can just use a straight CUDA implementation."
Of course, on CPU, one can delegate to a NumPy implementation using the `pure_callback` mechanism, which is what `xpx.lazy_apply`, used heavily in SciPy, does. I've been wanting to do the same kind of thing with delegation to a CuPy implementation, without having to transfer to host. I've been tinkering with the JAX FFI and have worked out a prototype implementation which allows this. It can be found here: https://github.com/steppi/jax_cupy_bridge/tree/main.

I'm curious whether my use of the FFI is within the bounds of what the JAX developers intended, and whether my approach is likely to keep working as the FFI continues to evolve, or whether I've just stumbled upon something that happens to work now but will be brittle. Being able to delegate from JAX to CuPy under the JIT would be very helpful for SciPy maintainers trying to round out support for the parts of SciPy's public API that are in scope for array API standard support. It would let us use already existing implementations rather than relying on JAX developers to write XLA implementations for everything we'd need. One way or another, I'd like to be able to delegate from JAX to CuPy under the JIT, so I hope the approach I've taken remains viable. If any JAX devs are aware of other approaches that may work better, I'd be happy to learn of them.
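For readers unfamiliar with the CPU-side mechanism mentioned above, here is a minimal sketch of delegating to a host (NumPy) implementation from under the JIT via `jax.pure_callback`. The `host_kernel` name is my own; `np.sin` just stands in for some NumPy-only kernel with no native JAX implementation:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical "host-only" kernel: pretend np.sin is a NumPy-only special
# function with no native JAX implementation.
def host_kernel(x):
    return np.sin(x)

@jax.jit
def f(x):
    # pure_callback ships the operand to host, runs the NumPy code, and
    # ships the result back; the result shape/dtype must be declared up front.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_kernel, out_spec, x)

x = jnp.linspace(0.0, 1.0, 8)
y = f(x)
```

This is exactly the pattern where a device-to-device variant would avoid the host round trip.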
## Tour
Here's a brief guide to how `jax_cupy_bridge` works:

- There is a C extension, https://github.com/steppi/jax_cupy_bridge/blob/main/src/jax_cupy_bridge/_bridge.cpp, which uses the header-only FFI library from XLA. It uses the Python C API to access the JAX arrays and JAX's CUDA stream from CuPy, importing and using the code from https://github.com/steppi/jax_cupy_bridge/blob/main/src/jax_cupy_bridge/_core.py within the C extension. The code in `_core.py` is where the JAX buffers get treated as CuPy arrays and a CuPy function is applied. Since this relies on the Python interpreter to handle the CuPy side of things, `_bridge.cpp` needs to hold the GIL; I'm not sure if there might be some unintended consequences of this. It's the aspect I most worry might make this idea cursed.
- The Python frontend lives in https://github.com/steppi/jax_cupy_bridge/blob/main/src/jax_cupy_bridge/_lazy.py, and this is where the handler is registered. Currently the CuPy function is passed to the handler using `id(func)`, which gives the object's address under CPython. This probably shouldn't be done in production code, since it relies on an implementation detail; I'm only doing it because this is a proof of concept, and I plan to iterate if my approach is found to not be completely cursed.
- There are benchmarks at https://github.com/steppi/jax_cupy_bridge/blob/main/benchmarks/benchmark1.py and https://github.com/steppi/jax_cupy_bridge/blob/main/benchmarks/benchmark2.py which verify that what I'm calling `cupy_lazy_apply` works under the JIT. The results are encouraging. On my Ubuntu workstation with an NVIDIA A4000 GPU, `benchmark1.py`, which evaluates `sin(erf(x))` where `x` is a random array of size `10^7`, showed that using the JAX-CuPy bridge is slower than using CuPy directly due to added overhead, but not terribly so. The second benchmark calculates `log(betainc(a, b, x))` where `a`, `b`, and `x` are random arrays of size `10^6`. Beyond potential performance gains from fusing, I think native JAX consistently wins here because `xsf` does not yet have native float32 implementations of `betainc` and `erf`, so it's doing the computations in `float64`, which is not well supported by my consumer-grade GPU.

I'd be happy to hear thoughts on the bridge I've written. Obviously, this would just be a last resort when there is no native JAX implementation to use. I suppose gradients could be handled by setting up custom VJPs, but for now the intention is just to support more of the SciPy public API with JAX as the backend, on GPU, while allowing the JIT. Thanks for adding the FFI; it makes it fairly straightforward to call out to external code.
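As an aside, the buffer-sharing idea at the heart of the bridge can be illustrated without any of my C extension code: CuPy's `cupy.asarray` will wrap, zero-copy, any object exposing the `__cuda_array_interface__` protocol. The class below is a hypothetical sketch of such a view (not the actual `_core.py` code); the pointer, shape, typestr, and stream would come from the buffers and CUDA stream XLA hands to the FFI handler. No GPU is needed just to inspect the protocol dict:

```python
class XlaBufferView:
    """Hypothetical zero-copy view of an XLA device buffer.

    An object exposing __cuda_array_interface__ (version 3) can be handed
    to cupy.asarray(), which wraps the raw device pointer without copying.
    """

    def __init__(self, device_ptr, shape, typestr, stream_ptr):
        self.device_ptr = device_ptr
        self.shape = shape
        self.typestr = typestr        # e.g. "<f4" for little-endian float32
        self.stream_ptr = stream_ptr  # the CUDA stream the buffer is valid on

    @property
    def __cuda_array_interface__(self):
        return {
            "version": 3,
            "data": (self.device_ptr, False),  # (device pointer, read_only)
            "shape": self.shape,
            "typestr": self.typestr,
            # Protocol v3 lets producers export the stream so the consumer
            # can synchronize before touching the memory.
            "stream": self.stream_ptr,
        }

# Dummy values, purely to show the shape of the protocol dict:
view = XlaBufferView(device_ptr=0xDEAD0000, shape=(1024,), typestr="<f4",
                     stream_ptr=7)
iface = view.__cuda_array_interface__
```

On a machine with a GPU and CuPy installed, `cupy.asarray(view)` (with a real pointer) would then give a CuPy array aliasing the JAX buffer.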
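For the curious, the `id(func)` round trip mentioned in the tour relies on a CPython implementation detail: `id()` returns the object's memory address, so the object can be recovered from the integer with `ctypes`, provided a strong reference keeps it alive. A sketch (names are mine, not the bridge's):

```python
import ctypes

def gpu_kernel(x):
    # Stand-in for a CuPy function that would be applied to the buffers.
    return x * x

# Keep a strong reference alive for as long as the handle may be used;
# otherwise the address could point at freed (or reused) memory.
_registry = [gpu_kernel]

handle = id(gpu_kernel)  # CPython: id() is the object's address

# On the receiving side (e.g. inside the FFI handler), recover the object:
recovered = ctypes.cast(handle, ctypes.py_object).value
```

This is why it's only fit for a proof of concept: it is CPython-specific and unsafe if the reference is ever dropped.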