2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
-kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
+kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
87 changes: 38 additions & 49 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
@@ -12,6 +12,7 @@
#include <functional>
#include <unordered_map>

#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
@@ -161,24 +162,7 @@ static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) {
return false;
}

-//optimization checks - is the implementation optimal for the conv request
-
-const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-
-auto M = ComputeConvOutSize(Parameters->InputShape[0], ComputeKernelSize(Parameters->DilationShape[0],
-Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) *
-ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1],
-Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]);
-auto N = Parameters->FilterCount;
-auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1];
-
-//Can use these variables to add other conditions as required
-MLAS_UNREFERENCED_PARAMETER(M);
-MLAS_UNREFERENCED_PARAMETER(K);
-MLAS_UNREFERENCED_PARAMETER(m_step);
-MLAS_UNREFERENCED_PARAMETER(n_step);

if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) {
KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks.");
return false;
@@ -314,8 +298,8 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
const float* in_data,
const float* pad_ptr) {
-
-auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = MlasDivRoundup(m, m_step);
@@ -399,7 +383,9 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i

const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);

-const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+
const auto lhs_ptrs_k = kh * kw;
const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step);
auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
@@ -505,13 +491,13 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
}

static void ConvolveSme(const size_t co, //channels out
-const size_t ci, //channels in
-const size_t ih, //image height
-const size_t iw, //image width
-const size_t kh, //kernel height
-const size_t kw, //kernel width
-const size_t sh, //kernel stride height
-const size_t sw, //kernel stride width
+const size_t ci, //channels in
+const size_t ih, //image height
+const size_t iw, //image width
+const size_t kh, //kernel height
+const size_t kw, //kernel width
+const size_t sh, //kernel stride height
+const size_t sw, //kernel stride width
const size_t dilationh, //kernel dilation stride
const size_t dilationw, //kernel dilation stride
const size_t padding, //padding size
@@ -532,10 +518,12 @@ static void ConvolveSme(const size_t co, //channels out
const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
ComputeConvOutSize(iw, d_kw, padding, sw);

-auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
-auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
+size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+: kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
+: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

-//tile iteration dimensions
+// tile iteration dimensions
std::array<size_t,3> dim;
dim[0] = 1; // B
dim[1] = MlasDivRoundup(m, m_step); // M
@@ -571,29 +559,23 @@ static void ConvolveSme(const size_t co, //channels out
auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool);
auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool);

-
-MlasTrySimpleParallel(ThreadPool,
-static_cast<ptrdiff_t>(dim[0]*dim[1]*dim[2]),
-[&](ptrdiff_t tid)
-{
+MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) {
//compute B,M,N index from iteration index
//ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

// Get rhs tile, B
-const size_t rhs_packed_offset =
-kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step,
-d_kh*d_kw,ci);
+const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
+: kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);

auto BTile = reinterpret_cast<const void*>(
reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
);

// Get lhs tile, A
-const size_t lhs_packed_offset =
-kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step,
-d_kh*d_kw,ci);
+const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
+: kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);

auto ATile = reinterpret_cast<const float*>(
reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
@@ -607,12 +589,19 @@ static void ConvolveSme(const size_t co, //channels out
MIdx * m_step * co * sizeof(float) +
NIdx * n_step * sizeof(float)];

-KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa"
-<< " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci);
-kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
-TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
--std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-);
+if (ArmKleidiAI::UseSME2) {
+KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
+TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+);
+} else {
+KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
+kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
+TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
+-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+);
+}
});

if (result == tmp_mlas_aligned) {
@@ -712,11 +701,11 @@ ArmKleidiAI::MlasConv(
)
{
if(!CheckCapabilitiesSme(Parameters)){
-//Fallback to Default Mlas
+// Fallback to Default Mlas
return false;
};
ConvolveSme(Parameters->FilterCount, Parameters->InputChannels, // channel out, in
-Parameters->InputShape[0], Parameters->InputShape[1], // image dimensions
+Parameters->InputShape[0], Parameters->InputShape[1], // image dimensions
Parameters->KernelShape[0], Parameters->KernelShape[1], // kernel dimensions
Parameters->StrideShape[0], Parameters->StrideShape[1], // kernel stride dimensions
Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation
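The convolution changes above all follow one pattern: a process-wide SME2 flag selects between two micro-kernel families that share packing formats, so every step-size query, packed-offset query, and kernel invocation has to switch together. Below is a minimal compilable sketch of that dispatch shape; the `sme2::*`/`sme::*` functions are hypothetical stand-ins for the paired `kai_*` entry points, not KleidiAI API.

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical stand-ins for the paired kai_get_m_step_* / kai_run_* entry
// points; real kernels differ only in the instructions they emit.
namespace sme2 {
inline size_t m_step() { return 8; }
inline void run(size_t rows) { std::cout << "SME2 tile, rows=" << rows << "\n"; }
}
namespace sme {
inline size_t m_step() { return 4; }
inline void run(size_t rows) { std::cout << "SME tile, rows=" << rows << "\n"; }
}

// Resolved once, like ArmKleidiAI::UseSME2 (hard-wired here for the sketch).
inline const bool UseSME2 = true;

void Convolve(size_t m) {
    // Tile geometry must come from the same family as the kernel that runs,
    // or packed-buffer offsets and loop bounds would disagree.
    const size_t m_step = UseSME2 ? sme2::m_step() : sme::m_step();
    for (size_t i = 0; i < m; i += m_step) {
        const size_t rows = (m - i < m_step) ? (m - i) : m_step;
        if (UseSME2) {
            sme2::run(rows);
        } else {
            sme::run(rows);
        }
    }
}

int main() { Convolve(20); }
```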
5 changes: 2 additions & 3 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -6,7 +6,7 @@

#pragma once

#include "mlasi.h"
#include "../mlasi.h"
#include <iostream>

// Fix to ensure compatibility with MSVC build
@@ -50,13 +50,12 @@
#endif

namespace ArmKleidiAI {
-
+// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();

//
// Buffer packing routines.
//

size_t
MLASCALL
MlasGemmPackBSize(
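`UseSME2` relies on a C++17 inline variable: every translation unit that includes the header shares one definition, and the CPUID probe runs once during static initialization rather than on every call. A reduced sketch of the idiom, with a hypothetical `ProbeCpuFeature()` standing in for `MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()`:

```cpp
// feature_flag.h -- sketch of the header-level feature flag (C++17).
#pragma once

// Hypothetical probe; the real code queries MLAS_CPUIDINFO.
inline bool ProbeCpuFeature() { return false; }

// `inline` gives one shared object across all translation units. A plain
// `const bool` at namespace scope has internal linkage, so each TU would
// otherwise get its own copy and run its own probe at static initialization.
inline const bool UseFastPath = ProbeCpuFeature();
```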
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/qgemm.cpp
@@ -210,9 +210,9 @@ MlasDynamicQGemmBatch (
MLAS_THREADPOOL* ThreadPool
) {
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
-//No fallback and putting in guards
-if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
-ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
+//No fallback and putting in guards. This implementation is SME2 specific.
+if(ArmKleidiAI::UseSME2){
+ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
}
#endif

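The qgemm guard tightens from `HasArm_SME()` to `UseSME2` because the dynamic QGemm path is implemented only with SME2 kernels: a dispatcher has to gate on the exact feature its kernel emits, not the wider family. A sketch of the distinction, with all names hypothetical:

```cpp
// Hypothetical feature probes standing in for the MLAS_CPUIDINFO queries.
inline bool HasSme()  { return true;  }  // e.g. an SME-only core
inline bool HasSme2() { return false; }

inline void RunSme2Kernel() { /* emits SME2 instructions */ }

void DynamicQGemm() {
    // Gating on HasSme() would be wrong: RunSme2Kernel() could fault on an
    // SME-only core even though the family check passed.
    if (HasSme2()) {  // exact feature required by the kernel
        RunSme2Kernel();
    }
    // No else branch: as in the diff, this path has no non-SME2 fallback.
}
```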