
I think that JAX does actually call geqrfBatched for batched systems:

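    // Heuristic: prefer the batched kernel (cuBLAS geqrfBatched) unless the
    // batch is small relative to the matrix dimensions.
    if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) {
      SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream,
                                scratch, a, out, tau);
    } else {
      // Otherwise, fall back to the non-batched geqrf implementation.
      SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out,
                           tau);
    }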

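There is a heuristic, though: the batched path is only taken when batch > 1 and both rows / batch and cols / batch are at most 128, so small batches of large matrices go through the non-batched kernel instead. This choice is made at runtime, not during lowering, so you won't see the cuBLAS function name embedded in the HLO. Hope this helps!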

Replies: 2 comments 3 replies

1 reply
@yantaow
Answer selected by lachinov
2 replies
@fliingelephant
@fliingelephant
Category
Q&A
4 participants