2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
-kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
+kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME()
#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME2()
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/mlas/inc/mlas.h"
@@ -213,9 +213,9 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
}
}

-// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
+// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
// We check that here too before attempting to use them.
-if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
+if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) {
can_use_dynamic_quant_mlas_ = false;
}

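The guard above is load-bearing: MlasDynamicQGemmBatch() has no internal fallback, so on hardware without SME2 it silently does nothing. A minimal, self-contained C++ sketch of that contract (the function names here are hypothetical stand-ins; the real probe is CPUIDInfo::GetCPUIDInfo().HasArm_SME2()):

    #include <iostream>

    // Hypothetical probe standing in for CPUIDInfo::GetCPUIDInfo().HasArm_SME2().
    static bool HasSme2() { return false; }

    // Models the MLAS behavior described above: without SME2 the batched
    // dynamic-quantized GEMM is a silent no-op.
    static void DynamicQGemmBatch() {
        if (!HasSme2()) return;  // no fallback inside the kernel itself
        // ... SME2 micro-kernel would run here ...
    }

    int main() {
        // Callers must mirror the kernel's capability check, as the operator does
        // with can_use_dynamic_quant_mlas_, or output buffers are never written.
        if (HasSme2()) {
            DynamicQGemmBatch();
        } else {
            std::cout << "SME2 absent: taking the default (non-KleidiAI) path\n";
        }
        return 0;
    }
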
87 changes: 38 additions & 49 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -12,6 +12,7 @@
#include <functional>
#include <unordered_map>

#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
@@ -161,24 +162,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
return false;
}

-//optimization checks - is the implementation optimal for the conv request
-
-const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-
-auto M = ComputeConvOutSize(Parameters->InputShape[0], ComputeKernelSize(Parameters->DilationShape[0],
-    Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) *
-    ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1],
-    Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]);
auto N = Parameters->FilterCount;
-auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1];
-
-//Can use these variables to add other conditions as required
-MLAS_UNREFERENCED_PARAMETER(M);
-MLAS_UNREFERENCED_PARAMETER(K);
-MLAS_UNREFERENCED_PARAMETER(m_step);
-MLAS_UNREFERENCED_PARAMETER(n_step);

if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks.");
return false;
@@ -314,8 +298,8 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
const float* in_data,
const float* pad_ptr) {

-auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+    : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = MlasDivRoundup(m, m_step);
@@ -399,7 +383,9 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i

const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);

-const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+    : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

const auto lhs_ptrs_k = kh * kw;
const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step);
auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
@@ -505,13 +491,13 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
}

static void ConvolveSme(const size_t co, //channels out
const size_t ci, //channels in
const size_t ih, //image height
const size_t iw, //image width
const size_t kh, //kernel height
const size_t kw, //kernel width
const size_t sh, //kernel stride height
const size_t sw, //kernel stride width
const size_t dilationh, //kernel dilation stride
const size_t dilationw, //kernel dilation stride
const size_t padding, //padding size
@@ -532,10 +518,12 @@ static void ConvolveSme(const size_t co, //channels out
const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
ComputeConvOutSize(iw, d_kw, padding, sw);

-auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+    : kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+    : kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

-//tile iteration dimensions
+// tile iteration dimensions
std::array<size_t,3> dim;
dim[0] = 1; // B
dim[1] = MlasDivRoundup(m, m_step); // M
@@ -571,29 +559,23 @@ static void ConvolveSme(const size_t co, //channels out
auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool);
auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool);


-MlasTrySimpleParallel(ThreadPool,
-    static_cast<ptrdiff_t>(dim[0]*dim[1]*dim[2]),
-    [&](ptrdiff_t tid)
-{
+MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) {
//compute B,M,N index from iteration index
//ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

// Get rhs tile, B
-const size_t rhs_packed_offset =
-    kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step,
-    d_kh*d_kw,ci);
+const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
+    : kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);

auto BTile = reinterpret_cast<const void*>(
reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
);

// Get lhs tile, A
-const size_t lhs_packed_offset =
-    kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step,
-    d_kh*d_kw,ci);
+const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
+    : kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);

auto ATile = reinterpret_cast<const float*>(
reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
@@ -607,12 +589,19 @@ static void ConvolveSme(const size_t co, //channels out
MIdx * m_step * co * sizeof(float) +
NIdx * n_step * sizeof(float)];

-KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
-    << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
-kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
-    TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
-    -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-);
+if (ArmKleidiAI::UseSME2) {
+    KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+    kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
+        TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+        -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+    );
+} else {
+    KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+    kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
+        TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+        -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+    );
+}
});

if (result == tmp_mlas_aligned) {
@@ -712,11 +701,11 @@ ArmKleidiAI::MlasConv(
)
{
if(!CheckCapabilitiesSme(Parameters)){
-//Fallback to Default Mlas
+// Fallback to Default Mlas
return false;
};
ConvolveSme(Parameters->FilterCount, Parameters->InputChannels, // channel out, in
Parameters->InputShape[0], Parameters->InputShape[1], // image dimensions
Parameters->KernelShape[0], Parameters->KernelShape[1], // kernel dimensions
Parameters->StrideShape[0], Parameters->StrideShape[1], // kernel stride dimensions
Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation
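ConvolveSme flattens the (batch, M, N) tile grid into a single index for MlasTrySimpleParallel, then recovers the tile coordinates inside the lambda. A standalone sketch of that arithmetic; DivRoundup stands in for MlasDivRoundup, and the TileSizeM/TileSizeN clamping is an assumption here, since the hunk that computes them is collapsed above:

    #include <algorithm>
    #include <array>
    #include <cstddef>
    #include <iostream>

    // Stand-in for MlasDivRoundup(): ceiling division.
    static size_t DivRoundup(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        // Example sizes only: an m x n output processed in m_step x n_step tiles.
        const size_t m = 100, n = 70, m_step = 32, n_step = 32;

        std::array<size_t, 3> dim{};
        dim[0] = 1;                     // B: batch
        dim[1] = DivRoundup(m, m_step); // M: row tiles    -> 4
        dim[2] = DivRoundup(n, n_step); // N: column tiles -> 3

        // Each tid maps back to a (B, M, N) tile, as in the ConvolveSme lambda.
        for (size_t tid = 0; tid < dim[0] * dim[1] * dim[2]; ++tid) {
            const size_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
            const size_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];
            // Edge tiles are clamped so the kernel never reads past m or n
            // (assumed convention; the computing hunk is not shown above).
            const size_t TileSizeM = std::min(m_step, m - MIdx * m_step);
            const size_t TileSizeN = std::min(n_step, n - NIdx * n_step);
            std::cout << "tid " << tid << " -> tile (" << MIdx << ", " << NIdx
                      << ") size " << TileSizeM << "x" << TileSizeN << "\n";
        }
        return 0;
    }

Note that m_step and n_step differ between the SME and SME2 micro-kernels, which is why the dispatch above must pick the step sizes and the packed offsets from the same variant it will eventually run.
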
5 changes: 2 additions & 3 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -6,7 +6,7 @@

#pragma once

#include "mlasi.h"
#include "../mlasi.h"
#include <iostream>

// Fix to ensure compatibility with MSVC build
@@ -50,13 +50,12 @@
#endif

namespace ArmKleidiAI {

+// By default we should try for SME2 first before falling back to SME.
+inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();

//
// Buffer packing routines.
//

size_t
MLASCALL
MlasGemmPackBSize(
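UseSME2 is a C++17 inline variable, so every translation unit that includes mlasi_kleidiai.h shares a single instance, initialized once from CPUID before first use; the kernels above then branch on it at every call site. A compilable sketch of the same select-once, branch-everywhere pattern (DetectSme2 and the get_m_step_* functions are hypothetical stand-ins for the MLAS_CPUIDINFO probe and the kai_get_m_step_* entry points):

    #include <cstddef>
    #include <iostream>

    // Hypothetical probe standing in for MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2().
    static bool DetectSme2() { return false; }

    namespace ArmKleidiAI {
    // C++17 inline variable: one instance program-wide, initialized once
    // during static initialization.
    inline const bool UseSME2 = DetectSme2();
    }  // namespace ArmKleidiAI

    // Hypothetical stand-ins for the SME2 / SME micro-kernel step getters.
    static size_t get_m_step_sme2() { return 4; }
    static size_t get_m_step_sme() { return 2; }

    int main() {
        // The ternary the diff applies at each kai_* call site: prefer SME2,
        // otherwise take the plain-SME variant of the same micro-kernel.
        const size_t m_step = ArmKleidiAI::UseSME2 ? get_m_step_sme2()
                                                   : get_m_step_sme();
        std::cout << "m_step = " << m_step << "\n";
        return 0;
    }
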
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/qgemm.cpp
@@ -210,9 +210,9 @@ MlasDynamicQGemmBatch (
MLAS_THREADPOOL* ThreadPool
) {
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
-//No fallback and putting in guards
-if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
-    ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
+//No fallback and putting in guards. This implementation is SME2 specific.
+if(ArmKleidiAI::UseSME2){
+    ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
}
#endif

6 changes: 3 additions & 3 deletions onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp
@@ -20,9 +20,9 @@ class MlasDynamicQgemmTest {

public:
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
-// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
-if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
-    GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test.";
+// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
+if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
+    GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME2 but it was not detected. Skipping test.";
}

// Setup buffers for holding various data
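Skipping, rather than passing vacuously, keeps the suite honest on hardware without SME2. A compilable sketch of the gating idiom, with a hypothetical probe in place of MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2():

    #include <gtest/gtest.h>

    // Hypothetical probe; the real test calls
    // MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2().
    static bool HasArmSme2() { return false; }

    TEST(DynamicQgemm, GatedOnSme2) {
      // GTEST_SKIP() marks the test as skipped and returns immediately, so the
      // body below never runs on CPUs without the feature.
      if (!HasArmSme2()) {
        GTEST_SKIP() << "ARM64 SME2 not detected; skipping.";
      }
      // ... the M x N x K batched dynamic-quantized GEMM checks would run here ...
      SUCCEED();
    }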