Hi, as I understand it, the cusolver geqrf call ignores batching. Manually implementing QR using Householder reflections (as in cublas) improves performance by a factor of 100 on an A100 for larger matrices, e.g. 64x32768x256. However, cublas already provides an optimized batched version, geqrfBatched, which would be a better fit. Is it possible to ask JAX or the XLA compiler to use the cublas backend for QR decomposition?
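For readers unfamiliar with the approach mentioned above, here is a minimal sketch of a Householder-based QR (R factor only) batched with jax.vmap. It is not the poster's implementation and makes no performance claim; the function name householder_qr_r and the small test shape are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def householder_qr_r(a):
    """R factor of one (m, n) matrix with m >= n (hypothetical helper)."""
    m, n = a.shape
    rows = jnp.arange(m)

    def step(k, r):
        # Column k with the entries above the diagonal masked out.
        x = jnp.where(rows >= k, r[:, k], 0.0)
        norm_x = jnp.linalg.norm(x)
        # Pick the sign that avoids cancellation when forming v.
        sign = jnp.where(x[k] >= 0, 1.0, -1.0)
        v = x.at[k].add(sign * norm_x)
        vtv = jnp.dot(v, v)
        # Skip the reflection if the column is already zero.
        scale = jnp.where(vtv > 0, 2.0 / vtv, 0.0)
        # Apply H = I - 2 v v^T / (v^T v) from the left.
        return r - scale * jnp.outer(v, v @ r)

    r = jax.lax.fori_loop(0, n, step, a)
    return jnp.triu(r)[:n, :]


# Batch over the leading axis; the shape is a small stand-in for 64x32768x256.
a = jax.random.normal(jax.random.PRNGKey(0), (4, 16, 8), dtype=jnp.float32)
r = jax.jit(jax.vmap(householder_qr_r))(a)
```

Because Q is orthogonal, R^T R should match A^T A for each matrix in the batch, which gives a quick correctness check that does not depend on sign conventions.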
Replies: 1 comment
I think that JAX does actually call geqrfBatched for batched systems (see jax/jaxlib/gpu/solver_kernels_ffi.cc, lines 299 to 305 in ea22c3b), although there is a heuristic to avoid batching for smaller batches of large matrices. This decision is made at runtime, not during lowering, so you won't see the cublas name embedded in the HLO. Hope this helps!
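To see what the lowering does contain, one can print the lowered module for a batched QR and look at the custom-call lines. This is a rough sketch: the input shape is arbitrary, and the custom-call target names that show up depend on the JAX version and backend, so none are assumed here.

```python
import jax
import jax.numpy as jnp

# Small batched input; the shape is only an example.
a = jnp.ones((4, 16, 8), dtype=jnp.float32)

# Lower (without executing) a jitted QR and get the module as text.
qr_r = jax.jit(lambda x: jnp.linalg.qr(x, mode="r"))
hlo_text = qr_r.lower(a).as_text()

# Collect any custom-call lines; these name the solver kernel entry point,
# not the cusolver/cublas routine that is picked at runtime.
custom_calls = [ln.strip() for ln in hlo_text.splitlines() if "custom_call" in ln]
for ln in custom_calls:
    print(ln[:160])
```

Since the batched-versus-unbatched choice happens inside the kernel at runtime, a profiler trace is the more direct way to confirm which GPU routine actually ran.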