Implement FP32 kleidiai Gemv #26302

base: main
Conversation
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
    kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
    kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
GetKleidiAIXUKernel() returns const&. Do we need to make a copy here?
Suggested change:

-   kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel();
-   kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel();
+   const kai_matmul_clamp_f32_f32p_f32p_ukernel& sgemm_gemm = GetKleidiAISGemmUKernel();
+   const kai_matmul_clamp_f32_f32_f32p_ukernel& sgemm_gemv = GetKleidiAISGemvUKernel();
Updated to const in the latest push.
onnxruntime/core/mlas/lib/qgemm.cpp (Outdated)
    //No fallback and putting in guards
-   if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
+   if(ArmKleidiAI::SMEInfo::CanUseSME2){
        ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
There are other places that need to be updated, like:

    if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
    if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
I might be missing some.
I think it would be worth making a helper function like MlasIsDynamicQGemmAvailable that has the appropriate checks and using that instead.
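For illustration, a minimal sketch of what such a helper might look like (the name comes from the suggestion above; whether the requirement is SME or SME2, and where the helper would live, are assumptions based on this thread):

    // Hypothetical helper centralizing the availability check for dynamic QGemm,
    // so callers do not repeat raw CPUID feature queries at every call site.
    inline bool
    MlasIsDynamicQGemmAvailable()
    {
        return MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
    }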
Added the updated checks in various places like these in the latest push.
> I think it would be worth making a helper function like MlasIsDynamicQGemmAvailable that has the appropriate checks and using that instead.
To clarify, this was the main suggestion.
Force-pushed from a3f4f5b to e8ab1b1
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
    void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
      // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
-     if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
+     if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
Nit: I guess the GTest skip comment needs a corresponding update too.
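For example, the guard and its comment might be updated together, roughly like this (exact wording assumed):

    // Currently, MlasDynamicQGemmBatch() and associated functions require SME2
    // (not just SME) or else they are no-ops, so skip on hardware without SME2.
    if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
        return;  // skip: SME2 unavailable
    }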
onnxruntime/core/mlas/lib/qgemm.cpp (Outdated)
    //No fallback and putting in guards
-   if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
+   if(ArmKleidiAI::SMEInfo::CanUseSME2){
        ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
I guess after merging #26301 the checks looking for SME2 will go away, i.e. it can run on both SME1 and SME2 then?
Yes, that's correct.
One change I've made in the latest push is to move this structure out of our KleidiAI code specifically and into mlasi.h, removing the ArmKleidiAI namespacing around it. That seemed like a sensible place to put it, given that other similar CPU-feature code already lives there.
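Roughly, the relocated declaration might look like this (a sketch pieced together from the diff snippets quoted in this thread; the CanUseSME companion flag is an assumption):

    // mlasi.h -- alongside the other CPU-feature helpers.
    struct SMEInfo {
        // By default we should try for SME2 first before falling back to SME.
        static const bool CanUseSME2;
        static const bool CanUseSME;  // assumed companion flag for plain SME
    };

    inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
    inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();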
Force-pushed from 4afc95c to 1d9b7c8
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Could you please rebase this @JonathanC-ARM?
Signed-off-by: Jonathan Clohessy <[email protected]>
Signed-off-by: Jonathan Clohessy <[email protected]>
Signed-off-by: Jonathan Clohessy <[email protected]>
Force-pushed from 1d9b7c8 to c9a507f
Hi @hariharans29, I've updated the branch now, thanks!
Signed-off-by: Jonathan Clohessy <[email protected]>
Force-pushed from c9a507f to 1ead7ca
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull Request Overview

This PR implements FP32 GEMV (matrix-vector) optimizations for KleidiAI, addressing degenerate matrix multiplication cases where M=1 or N=1. The implementation introduces a microkernel interface abstraction to simplify code and remove conditional logic for SME/SME2 kernel selection.

Key changes:
- Added specialized GEMV path for M=1 and N=1 cases in FP32 SGEMM operations
- Introduced SMEInfo struct with static boolean constants to replace scattered SME capability checks
- Refactored microkernel selection using typedef interfaces instead of ternary operations

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test_fgemm_fixture.h | Added test cases for GEMV scenarios (M=1 and N=1) |
| test_dynamic_qgemm.cpp | Updated to use SMEInfo for SME2 capability check |
| qgemm.cpp | Replaced direct CPU feature checks with SMEInfo struct |
| platform.cpp | Updated SME availability check to use SMEInfo |
| mlasi.h | Added SMEInfo struct declaration and inline definitions |
| sgemm_kleidiai.cpp | Implemented GEMV functions with helper utilities and integrated with GEMM batch |
| mlasi_kleidiai.h | Removed local UseSME2 variable, added MlasFp32Gemv declaration |
| convolve_kleidiai.cpp | Updated to use SMEInfo for SME capability checks |
| kai_ukernel_interface.h | Added SGEMM/SGEMV ukernel interface declarations |
| kai_ukernel_interface.cpp | Implemented ukernel selection functions for SGEMM/SGEMV |
| convolve.cpp | Added SME availability guard and formatting fixes |
| dynamic_quantize_matmul.cc | Updated to use SMEInfo for SME2 capability check |
| deps.txt | Updated pytorch_cpuinfo dependency version |
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot <[email protected]>
Please rebase to include this: #26559
Hi @hariharans29, I've gone ahead and synced my fork, so this branch should include that change now. Thanks for letting me know!
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull Request Overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.
    if (N == 1 && TransB == CblasNoTrans)
    {
        g_kai_tls.gemv_lhs_row_tmp.resize(K);

        for (size_t k = 0; k < K; ++k) {
            g_kai_tls.gemv_lhs_row_tmp[k] = lhs_base[k * Data[b].ldb];
        }
        lhs_base = g_kai_tls.gemv_lhs_row_tmp.data();
    }
Copilot AI commented on Nov 20, 2025:
The gather logic has an issue when M == 1 && N == 1. In this case:

- Line 158 sets lhs_base = Data[b].A (taking the M == 1 path)
- But line 184 checks TransB and line 189 uses Data[b].ldb to stride through the data
- This is inconsistent: when M == 1, we're using A as the LHS vector, so we should check TransA and use Data[b].lda for striding

When M == 1 && N == 1, if we take the M == 1 path (which the code does), the gather should check TransA and use lda since the LHS is now A, not B.
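A sketch of the correction being suggested, mirroring the quoted code (hypothetical; the actual fix in the PR may differ):

    // When M == 1 the LHS vector comes from A, so a strided gather is only
    // needed when A is stored transposed, and it must stride by lda, not ldb.
    if (M == 1 && TransA == CblasTrans)
    {
        g_kai_tls.gemv_lhs_row_tmp.resize(K);

        for (size_t k = 0; k < K; ++k) {
            g_kai_tls.gemv_lhs_row_tmp[k] = lhs_base[k * Data[b].lda];
        }
        lhs_base = g_kai_tls.gemv_lhs_row_tmp.data();
    }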
    // of the ONNX Runtime source tree. OpenMP may or may not be enabled in this
    // configuration.
    //
    struct SMEInfo {
Can we guard this with #if defined(MLAS_TARGET_ARM64), as this is ARM specific?
Sure, makes sense; will make the change.
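Something along these lines (exact placement assumed):

    #if defined(MLAS_TARGET_ARM64)
    struct SMEInfo {
        static const bool CanUseSME2;
    };
    #endif  // MLAS_TARGET_ARM64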
    // Boolean condition to determine if we can use SME2
    // By default we should try for SME2 first before falling back to SME.
    inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
Can we avoid duplicate initializations by just moving the initializations out of the BUILD_MLAS_NO_ONNXRUNTIME guarded sections?
Not 100% sure, but I can check, and I'd prefer that to be honest.
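The idea, as I read the suggestion, would be roughly this (the surrounding build-guard layout is an assumption):

    #if defined(BUILD_MLAS_NO_ONNXRUNTIME)
    // ... definitions specific to the standalone MLAS build ...
    #else
    // ... definitions specific to the full ONNX Runtime build ...
    #endif

    // Shared: a single initialization outside the guarded sections, so it is
    // not duplicated per configuration.
    inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();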
    pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2
    pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9
-   pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/877328f188a3c7d1fa855871a278eb48d530c4c0.zip;9152d4bf6b8bde9f19b116de3bd8a745097ed9df
+   pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/de0ce7c7251372892e53ce9bc891750d2c9a4fd8.zip;c45b8d3619b9bccbd26dc5f657959aee38b18b7a
Can you please remind me: why do we need this cpuinfo dependency update for this PR?
I don't think I made this change; it might somehow be from an older version of the deps file. Will look into it.
    // Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
    // We check that here too before attempting to use them.
-   if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) {
+   if (!SMEInfo::CanUseSME2) {
Can we skip the qgemm.cpp / dynamic_quantize_matmul.cc / test_dynamic_qgemm.cpp changes in this PR? #26598 is taking care of it with a new MLAS API for this.
Sure, I can revert this change altogether in favor of #26598.
-   // By default we should try for SME2 first before falling back to SME.
-   inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
-
+   //
Stray change?
Yes, that seems to be the case. I removed the above code, but it looks like I left the initial //.
    bool
    MLASCALL
    MlasFp32Gemv(
Consider renaming to MlasGemvBatch to be consistent with MlasGemmBatch? Thoughts on this?
Seems like a sensible suggestion; it doesn't really impact anything other than making things consistent, so I'd be happy to make the change.
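For reference, the renamed declaration might read as follows (the parameter types are assumptions inferred from the call sites quoted in this thread, not the actual signature):

    bool
    MLASCALL
    MlasGemvBatch(
        CBLAS_TRANSPOSE TransA,
        CBLAS_TRANSPOSE TransB,
        size_t M,
        size_t N,
        size_t K,
        const MLAS_SGEMM_DATA_PARAMS* Data,
        size_t BatchSize
        );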
    // Attempt GEMV (M==1 or N==1)
    if (M == 1 || N == 1)
    {
        if (ArmKleidiAI::MlasFp32Gemv(TransA, TransB, M, N, K, Data, BatchSize)) {
Any scope for using multiple threads in the Gemv implementation? (I see that the Gemv routine doesn't take in the ThreadPool param.)
If there are plans to add a multi-threaded implementation for Gemv in the future, can you please add a TODO for that?
I will likely add a TODO. It's not something we tested, to be honest, so we'd need to investigate whether a threaded implementation provides any performance benefit, given that the kernels in question can handle the entire matrix in a single execution as-is. There may be cases where splitting the workload would be beneficial, but without testing I'm just speculating.
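Something like this could mark the spot (wording assumed):

    // TODO: investigate a multi-threaded GEMV path. The current kernels handle
    // the whole matrix in a single execution, so any benefit from splitting the
    // workload across a ThreadPool is untested.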
    if (M == 1 || N == 1)
    {
        if (ArmKleidiAI::MlasFp32Gemv(TransA, TransB, M, N, K, Data, BatchSize)) {
            return true;
Just checking: in case the Gemv execution flow returns false for some reason, I am guessing you want to try KleidiAI's Gemm before falling back to MLAS? That is how it is right now, but I wanted to check if that is the intended flow.
That's the intention, yes: attempt the op and fall back to MLAS if we cannot proceed for some reason.
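To spell the flow out, a sketch of the intended dispatch order (illustrative only; the real control flow lives in sgemm_kleidiai.cpp):

    // 1. Degenerate shapes first: try the KleidiAI GEMV kernels.
    // 2. If GEMV declines (returns false), continue to the KleidiAI GEMM path.
    // 3. If that also cannot proceed, returning false hands the call back to
    //    the generic MLAS implementation.
    if (M == 1 || N == 1) {
        if (ArmKleidiAI::MlasFp32Gemv(TransA, TransB, M, N, K, Data, BatchSize)) {
            return true;  // handled by the GEMV kernels
        }
    }
    // ... fall through to the KleidiAI GEMM path; a false return from here
    // triggers the MLAS fallback in the caller.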
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Description
Implementation of a special SGEMM path which uses GEMV kernels in cases where M or N is 1.
Additionally, this PR introduces a microkernel interface which uses typedefs provided by KleidiAI, simplifying the code and removing things such as ternary operations for selecting between SME1 and SME2 kernels.
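The pattern, in a self-contained form (all names here are hypothetical stand-ins, not the actual KleidiAI symbols):

    #include <cstddef>

    // A function-pointer table stands in for the KleidiAI ukernel typedef.
    struct hypothetical_sgemm_ukernel {
        void (*run_matmul)(std::size_t M, std::size_t N, std::size_t K);
    };

    // Select the SME2 or SME variant once, behind one getter, instead of
    // repeating the ternary at every call site.
    const hypothetical_sgemm_ukernel& GetSGemmUKernel(bool can_use_sme2) {
        static const hypothetical_sgemm_ukernel sme2_variant{nullptr};
        static const hypothetical_sgemm_ukernel sme_variant{nullptr};
        return can_use_sme2 ? sme2_variant : sme_variant;
    }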
Indicative Performance
In lieu of any production models where GEMV is a large contributor to the network, I opted to create a mini test model containing thousands of randomized MatMul variants, with a distribution of GEMV cases throughout.

Using onnxruntime_perf_test, I was able to halve the total inference time vs MLAS with this model.

More benchmarks to come shortly.