diff --git a/cmake/deps.txt b/cmake/deps.txt index 22c793705ddeb..e1870bf2df0cf 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a -kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2 duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794 diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp index 1d682b372e2f5..28c6fa3037955 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -11,10 +11,18 @@ #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" #include "mlasi_kleidiai.h" +// Thread-local reusable buffers to reduce allocation overhead across tiles. +struct KaiTlsBuffersQgemm { + std::vector lhs_packed; + std::vector lhs_base_table; +}; +static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm; + //Matmul with float output of dynamic quantized A and symmetric quantized B. size_t @@ -80,42 +88,135 @@ MLASCALL ArmKleidiAI::MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, + const size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) { - for (auto b = BatchN; b > 0; --b,++DataParams) { - auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); + const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); + const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); - //TODO enable multi-threading for lhs packing and matmul - MLAS_UNREFERENCED_PARAMETER(ThreadPool); + size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); + size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa() + : kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(); - //Dynamic Quantize A - lhs - auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); - std::byte* lhs = nullptr; - std::unique_ptr fallback; - if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { - lhs = static_cast(DataParams->Workspace); + if (Shape.M == 0 || Shape.N == 0 || Shape.K == 0) { + return; + } + if ((Shape.M < m_step || Shape.N < n_step) && !DataParams->PackedB) { + // Fallback to MLAS + ORT_ENFORCE(false, "ArmKleidiAI::MlasDynamicQGemmBatch(): unsupported small-shape case (M < m_step or N < n_step)"); + } + + //Dynamic Quantize A - lhs + const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* LhsPackedData = nullptr; + + if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) { + + g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize); + } + g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize); + LhsPackedData = g_kai_tls_qgemm.lhs_packed.data(); + + //Per-batch table of lhs + if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) { + + g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize); + } + g_kai_tls_qgemm.lhs_base_table.resize(BatchSize); + // Capture the shared batch table pointer so worker threads use the same backing storage. + const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data(); + // B batches require no packing + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + + std::byte* lhs = nullptr; + if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) { + lhs = static_cast(DataParams[batch_idx].Workspace); } else { - fallback = std::make_unique(lhs_size); - lhs = fallback.get(); + lhs = &(LhsPackedData[LhsPackedStride * batch_idx]); } - KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32" - << " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0"); - kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, - Shape.K*sizeof(float), lhs); - - KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa"); - kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( - Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, - DataParams->C, - Shape.N * sizeof(float), - sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs); + tls_lhs_base[batch_idx] = lhs; + }); + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(Shape.M, m_step); // M + dim[2] = MlasDivRoundup(Shape.N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(Shape.M, m_step); + dim[2] = MlasDivRoundup(Shape.N, n_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K) + : kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K); + + const std::byte* B_base = reinterpret_cast(DataParams[BIdx].PackedB); + auto BTile = reinterpret_cast(B_base + rhs_packed_offset); + + // Get lhs tile, A + const size_t lhs_packed_offset = + UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K) + : kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K); + + const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace; + auto ATile = reinterpret_cast(A_base + lhs_packed_offset); + + auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? (Shape.M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? (Shape.N - NIdx * n_step) : n_step; + + float* dst_tile = reinterpret_cast( + reinterpret_cast(DataParams[BIdx].C) + + MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) ); - } + + if (UseSME2) { + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( + TileSizeM, TileSizeN, Shape.K, ATile, BTile, + dst_tile, + DataParams[BIdx].ldc * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } + else { + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa( + TileSizeM, TileSizeN, Shape.K, ATile, BTile, + dst_tile, + DataParams[BIdx].ldc * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } + }); } diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp index 6d05e93f517ae..3baccc7faeea9 100644 --- a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -10,7 +10,10 @@ #include "test_util.h" #include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO -class MlasDynamicQgemmTest { +#include +#include + +class MlasDynamicQgemmTestBase { private: MatrixGuardBuffer buffer_a; MatrixGuardBuffer buffer_bf; @@ -18,15 +21,16 @@ class MlasDynamicQgemmTest { MatrixGuardBuffer buffer_c; MatrixGuardBuffer buffer_c_ref; - public: - void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + protected: + void Run(size_t M, size_t N, size_t K, size_t BatchSize, + MLAS_THREADPOOL* threadpool, bool require_threadpool, const char* run_tag) { // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. - if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test."; - } + if (require_threadpool && threadpool == nullptr) + GTEST_SKIP() << "Dynamic QGEMM threading path requested but no MLAS thread pool is available."; // Setup buffers for holding various data - float* A = buffer_a.GetBuffer(M * K * BatchSize); // Buffer for holding floating point version of weight matrix float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); @@ -44,6 +48,10 @@ class MlasDynamicQgemmTest { // Quantize Bf → Bq and compute per-column scale and bias per batch std::vector> b_scale_batches(BatchSize, std::vector(N)); std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + std::vector> a_quant_batches( + BatchSize, std::vector(M * K)); + std::vector> a_scale_batches(BatchSize, std::vector(M)); + std::vector> a_zero_point_batches(BatchSize, std::vector(M)); for (size_t b = 0; b < BatchSize; ++b) { for (size_t n = 0; n < N; ++n) { @@ -66,6 +74,42 @@ class MlasDynamicQgemmTest { } } + // Quantize A rows to match the dynamic quantization performed by the kernel. + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + for (size_t k = 0; k < K; ++k) { + float v = A[b * M * K + m * K + k]; + min_val = std::min(min_val, v); + max_val = std::max(max_val, v); + } + float rmin = std::min(0.0f, min_val); + float rmax = std::max(0.0f, max_val); + float inv_scale = (rmax == rmin) ? 1.0f : 255.0f / (rmax - rmin); + float scale = inv_scale ? 1.0f / inv_scale : 0.0f; + float descaled_min = rmin * inv_scale; + float descaled_max = rmax * inv_scale; + float zero_point_from_min_error = -128.0f + descaled_min; + float zero_point_from_max_error = 127.0f + descaled_max; + float zero_point = (zero_point_from_min_error + zero_point_from_max_error > 0.0f) + ? (-128.0f - descaled_min) + : (127.0f - descaled_max); + zero_point = std::clamp(zero_point, -128.0f, 127.0f); + int32_t zp = static_cast(std::nearbyint(zero_point)); + + a_scale_batches[b][m] = scale; + a_zero_point_batches[b][m] = zp; + + for (size_t k = 0; k < K; ++k) { + float v = A[b * M * K + m * K + k]; + int32_t q = static_cast(std::round(v * inv_scale)) + zp; + q = std::clamp(q, -128, 127); + a_quant_batches[b][m * K + k] = static_cast(q); + } + } + } + // Prepare kernel parameters MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); @@ -86,16 +130,17 @@ class MlasDynamicQgemmTest { params[b].PackedB = packed_b; } - // call MlasDynamicQGemmBatch Function - MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); // Compute reference result for (size_t b = 0; b < BatchSize; ++b) { for (size_t m = 0; m < M; ++m) { for (size_t n = 0; n < N; ++n) { float sum = 0.0f; + const float a_scale = a_scale_batches[b][m]; + const int32_t a_zero_point = a_zero_point_batches[b][m]; for (size_t k = 0; k < K; ++k) { - float a = A[b * M * K + m * K + k]; + int32_t a_q = static_cast(a_quant_batches[b][m * K + k]); + float a = static_cast(a_q - a_zero_point) * a_scale; float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; sum += a * bval; } @@ -104,45 +149,67 @@ class MlasDynamicQgemmTest { } } - // Validate results - for (size_t i = 0; i < M * N * BatchSize; ++i) { - float abs_c_ref = std::abs(CRef[i]); - float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f; - float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); - float abs_tol = 3.0f; - float allowed = std::max(rel_tol, abs_tol); - float diff = std::abs(C[i] - CRef[i]); - ASSERT_LE(diff, allowed); - } - } + std::fill(C, C + M * N * BatchSize, 0.0f); + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, threadpool); - static const char* GetTestSuiteName() { - return "DynamicQgemm"; + // Validate results + auto validate = [&](const char* tag) { + SCOPED_TRACE(tag); + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_c_ref = std::abs(CRef[i]); + float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f; + float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); + float abs_tol = 3.0f; + float allowed = std::max(rel_tol, abs_tol); + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, allowed); + } + }; + + validate(run_tag); } }; -class DynamicQgemmExecuteTest : public MlasTestFixture { + class MlasDynamicQgemmSingleThreadTest : public MlasDynamicQgemmTestBase { + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + Run(M, N, K, BatchSize, /*threadpool*/ nullptr, /*require_threadpool*/ false, "SingleThread"); + } + static const char* GetTestSuiteName() { return "DynamicQgemmSingleThread"; } + }; + + class MlasDynamicQgemmThreadPoolTest : public MlasDynamicQgemmTestBase { + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + MLAS_THREADPOOL* tp = GetMlasThreadPool(); + if (!tp) GTEST_SKIP() << "Mlas thread pool not available"; + Run(M, N, K, BatchSize, tp, /*require_threadpool*/ true, "ThreadPool"); + } + static const char* GetTestSuiteName() { return "DynamicQgemmThreaded"; } + }; + +template +class DynamicQgemmExecuteTest : public MlasTestFixture { public: DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} void TestBody() override { - this->mlas_tester->Test(M_, N_, K_, BatchSize_); + MlasTestFixture::mlas_tester->Test(M_, N_, K_, BatchSize_); } static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { std::stringstream ss; ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; - std::string test_name = ss.str(); testing::RegisterTest( - MlasDynamicQgemmTest::GetTestSuiteName(), + TMlasTester::GetTestSuiteName(), test_name.c_str(), nullptr, test_name.c_str(), __FILE__, __LINE__, - [=]() -> MlasTestFixture* { + [=]() -> MlasTestFixture* { return new DynamicQgemmExecuteTest(M, N, K, BatchSize); }); @@ -166,7 +233,11 @@ class DynamicQgemmExecuteTest : public MlasTestFixture { size_t M_, N_, K_, BatchSize_; }; -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +static UNUSED_VARIABLE bool added_single = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); + +static UNUSED_VARIABLE bool added_threaded = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); }); #endif