
I think JAX does actually call cuBLAS `geqrfBatched` for batched systems:

```cpp
if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) {
  SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream,
                            scratch, a, out, tau);
} else {
  SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out,
                       tau);
}
```

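although there is a heuristic that avoids the batched kernel for small batches of large matrices: the batched path is taken only when `batch > 1` and both `rows / batch` and `cols / batch` are at most 128. This decision is made at runtime, not during lowering, so you won't see the cuBLAS function name embedded in the HLO. Hope this helps!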


dfm
Sep 2, 2025
Collaborator

Answer selected by lachinov