Hi, as I understand it, cusolver's geqr ignores batching when called from JAX. Manually implementing QR with Householder reflections (as in cublas) improves performance by a factor of 100 on an A100 for larger batches of matrices, e.g. 64×32768×256. However, cublas already ships an optimized batched routine, geqrfBatched, which would be a better fit. Is it possible to ask JAX or the XLA compiler to use the cublas backend for QR decomposition?
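For context, a minimal sketch of batching QR over the leading axis with jax.vmap, the pattern the question is about. The shapes here are small stand-ins for the 64×32768×256 case so the example runs quickly; they are illustrative, not from the original post.

```python
import jax
import jax.numpy as jnp

# vmap maps jnp.linalg.qr over the leading (batch) axis; jit compiles it.
batched_qr = jax.jit(jax.vmap(jnp.linalg.qr))

# Stand-in for a batch of tall-skinny matrices (original: 64 x 32768 x 256).
a = jax.random.normal(jax.random.PRNGKey(0), (4, 128, 32))
q, r = batched_qr(a)
print(q.shape, r.shape)  # (4, 128, 32) (4, 32, 32)
```

The default "reduced" mode returns per-matrix factors q (m×n) and r (n×n); which cusolver/cublas kernel actually runs underneath is decided by the backend, which is what the rest of this thread discusses.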
Replies: 2 comments 3 replies
I think that JAX does actually call geqrfBatched for batched systems: jax/jaxlib/gpu/solver_kernels_ffi.cc, lines 299 to 305 in ea22c3b. However, there is a heuristic that avoids the batched path for small batches of large matrices. That decision is made at runtime, not during lowering, so you won't see the cublas name embedded in the HLO. Hope this helps!
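A quick way to see the point about lowering: inspecting the lowered program for a vmapped QR shows a generic custom-call target rather than a cublas-vs-cusolver choice. This is a hedged sketch; the exact custom-call target name in the text depends on the backend and JAX version.

```python
import jax
import jax.numpy as jnp

fn = jax.jit(jax.vmap(jnp.linalg.qr))
a = jnp.ones((8, 64, 64))

# The lowered StableHLO text names a generic geqrf custom-call target;
# the cublas/cusolver dispatch happens later, at runtime.
hlo = fn.lower(a).as_text()
print("cublas" in hlo.lower())
```

If the heuristic mentioned above picked geqrfBatched at lowering time instead, the string "cublas" would appear here; because the choice is deferred to runtime, it does not.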
QR under JAX's vmap
Benchmark
I benchmarked SVD, QR, and Cholesky on 64×64 matrices across different batch sizes. The results show that:
[Table: Total Time (milliseconds) per batch size]
[Table: Scaling Ratio relative to batch=1 (ideal = 1.0)]
Analysis
An nsys profile of vmapped QR shows that the factorization itself does use a batched kernel (…).
Inspecting the repo confirms that there is currently no batched dispatch for the relevant routine:
jax/jaxlib/gpu/solver_kernels_ffi.cc, lines 375 to 381 in accb719
jax/jaxlib/gpu/solver_kernels_ffi.cc, lines 345 to 350 in accb719
This might explain the scaling behavior observed for batched QR compared to Cholesky.
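For reproducibility, a hedged sketch of the kind of benchmark described above: timing vmapped QR and Cholesky on 64×64 matrices across batch sizes. The helper `bench` and the chosen batch sizes are illustrative, not from the original post, and wall-clock timings on CPU will not match the GPU numbers discussed here.

```python
import time
import jax
import jax.numpy as jnp

def bench(fn, x, iters=10):
    """Average wall-clock time per call, after a compile/warm-up run."""
    jax.block_until_ready(fn(x))  # trigger compilation first
    t0 = time.perf_counter()
    for _ in range(iters):
        out = fn(x)
    jax.block_until_ready(out)  # wait for async dispatch to finish
    return (time.perf_counter() - t0) / iters

key = jax.random.PRNGKey(0)
for batch in (1, 8, 64):
    a = jax.random.normal(key, (batch, 64, 64))
    # Make a symmetric positive-definite batch for Cholesky.
    spd = a @ a.transpose(0, 2, 1) + 64 * jnp.eye(64)
    t_qr = bench(jax.jit(jax.vmap(jnp.linalg.qr)), a)
    t_chol = bench(jax.jit(jax.vmap(jnp.linalg.cholesky)), spd)
    print(f"batch={batch:3d}  qr={t_qr*1e3:.3f} ms  chol={t_chol*1e3:.3f} ms")
```

Comparing how each column grows with batch size (against the ideal ratio of 1.0 for fully batched kernels) is what surfaces the missing batched dispatch described above.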