2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
Member:
With this update of the KAI version from 1.10 to 1.15, can SME/SME2 detection also be enabled on Windows to leverage the kernels?

https://github.com/microsoft/onnxruntime/pull/25187/files#r2223006773
https://github.com/microsoft/onnxruntime/pull/25760/files#r2325260570
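For illustration only: the dispatch guard quoted later in this thread currently reads #if defined(USE_KLEIDIAI) && !defined(_MSC_VER), so enabling the kernels on Windows would presumably mean dropping the MSVC exclusion while keeping the runtime SME check. A sketch, assuming the MSVC exclusion is the only blocker (which may not be the case):

#if defined(USE_KLEIDIAI)  // note: no !defined(_MSC_VER) exclusion in this sketch
    // Runtime detection gate; HasArm_SME() is the check the dispatcher already uses.
    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
        ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
    }
#endif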

duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
157 changes: 129 additions & 28 deletions onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
@@ -11,10 +11,18 @@
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"

#include "mlasi_kleidiai.h"

// Thread-local reusable buffers to reduce allocation overhead across tiles.
struct KaiTlsBuffersQgemm {
std::vector<std::byte> lhs_packed;
std::vector<const std::byte*> lhs_base_table;
};
static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm;

//Matmul with float output of dynamic quantized A and symmetric quantized B.

size_t
@@ -80,42 +88,135 @@ MLASCALL
ArmKleidiAI::MlasDynamicQGemmBatch(
const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
const size_t BatchN,
const size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
) {
for (auto b = BatchN; b > 0; --b,++DataParams) {
auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();
auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa();

const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();

//TODO enable multi-threading for lhs packing and matmul
MLAS_UNREFERENCED_PARAMETER(ThreadPool);
size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa()
: kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();

//Dynamic Quantize A - lhs
auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
std::byte* lhs = nullptr;
std::unique_ptr<std::byte[]> fallback;

if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) {
lhs = static_cast<std::byte*>(DataParams->Workspace);
if (Shape.M == 0 || Shape.N == 0 || Shape.K == 0) {
return;
}
if ((Shape.M < m_step || Shape.N < n_step) && !DataParams->PackedB) {
// Fallback to MLAS
Contributor:
there is no fallback implementation of MlasDynamicQGemmBatch().

#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback and putting in guards
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
}
#endif
MLAS_UNREFERENCED_PARAMETER(Shape);
MLAS_UNREFERENCED_PARAMETER(DataParams);
MLAS_UNREFERENCED_PARAMETER(BatchN);
MLAS_UNREFERENCED_PARAMETER(ThreadPool);

if we get to this point, the computation should happen or (maybe less preferably) it should be a hard error.

Author:
We will investigate the fallback case further and try to provide a better implementation.
Until then, we would like to get your opinion on using ORT_ENFORCE:

ORT_ENFORCE(false, "ArmKleidiAI::MlasDynamicQGemmBatch(): unsupported small-shape case (M < m_step or N < n_step)");

Member (@hariharans29, Oct 31, 2025):
Could we instead implement @edgchen1's suggestion in the other PR (#26302 (comment)) to have a universal check, usable in all places, of whether MLAS supports QGemm for a given problem shape, platform, etc.?

Also, since we have a check on the M dimension, this might need some thinking. In the current setup, we turn off MLAS usage for QGemm in PrePack() if we don't detect SME or the weight's shape doesn't match the requirements in PrePack(). See here and here. The M dimension won't be known in PrePack().

Just curious: what would happen if M was < m_step? Would there be a crash, or would the perf be sub-optimal? If so, we need to add a runtime check in the CPU kernel's Run() function, which means we may need to perform pre-packing for both KAI and the "regular" path. See here.
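For illustration, such a universal check might look something like the sketch below; the helper name and exact signature are assumptions, not an existing MLAS API:

// Hypothetical helper (sketch only): one predicate that PrePack() and Run()
// could both consult. Pass M == 0 from PrePack(), where M is not yet known.
bool MlasDynamicQGemmIsSupported(size_t M, size_t N, size_t K) {
    if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
        return false;  // the KleidiAI dynamic QGEMM path needs SME/SME2
    }
    const size_t m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
    const size_t n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa();
    if (N < n_step || K == 0) {
        return false;
    }
    if (M != 0 && M < m_step) {
        return false;  // small-M case that Run() would have to reject or route elsewhere
    }
    return true;
}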

ORT_ENFORCE(false, "ArmKleidiAI::MlasDynamicQGemmBatch(): unsupported small-shape case (M < m_step or N < n_step)");
}

//Dynamic Quantize A - lhs
const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr);
std::byte* LhsPackedData = nullptr;

if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) {

g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize);
}
g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize);
Member:
Can't we just do the resizing directly instead of reserve + resize?

Author:
Yes, both reserve() + resize() and resize() alone end up with one allocation plus one initialisation. But somehow there is a very small performance difference between separating the allocation and initialisation and doing them at once with resize(). ("after" is the case where the reserve() calls are removed and only resize() is used.)
[Attached benchmark image: ort_ops_compare_2_thread_before_2025-10-29_13-08-56_vs_2_thread_after_2025-10-29_13-32-05]
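For reference, the two variants compared are effectively the following sketch (buf stands in for the thread-local lhs_packed vector; LhsPackedStride and BatchSize as in the kernel code):

std::vector<std::byte> buf;  // stands in for g_kai_tls_qgemm.lhs_packed

// Variant A ("before", as in this PR): reserve, then resize.
buf.reserve(LhsPackedStride * BatchSize);  // single allocation
buf.resize(LhsPackedStride * BatchSize);   // value-initialises the new bytes

// Variant B ("after"): resize alone performs the same allocation plus
// value-initialisation in one call.
// buf.resize(LhsPackedStride * BatchSize);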

LhsPackedData = g_kai_tls_qgemm.lhs_packed.data();

//Per-batch table of lhs
if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) {

g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize);
}
g_kai_tls_qgemm.lhs_base_table.resize(BatchSize);
// Capture the shared batch table pointer so worker threads use the same backing storage.
const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data();
// B batches require no packing
// We have already decided the matmul variant we are using, before having values for M,N,K
MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {

std::byte* lhs = nullptr;
if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) {
lhs = static_cast<std::byte*>(DataParams[batch_idx].Workspace);
} else {
fallback = std::make_unique<std::byte[]>(lhs_size);
lhs = fallback.get();
lhs = &(LhsPackedData[LhsPackedStride * batch_idx]);
}

KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32"
<< " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0");
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A,
Shape.K*sizeof(float), lhs);

KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa");
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB,
DataParams->C,
Shape.N * sizeof(float),
sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs);
tls_lhs_base[batch_idx] = lhs;
});

// tile iteration dimensions
std::array<size_t, 3> dim;
dim[0] = BatchSize; // B
dim[1] = MlasDivRoundup(Shape.M, m_step); // M
dim[2] = MlasDivRoundup(Shape.N, n_step); // N

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);

// scale required tiles over available tile processors
dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

// compute new step sizes
m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step);
n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step);

// update tile iterations
dim[1] = MlasDivRoundup(Shape.M, m_step);
dim[2] = MlasDivRoundup(Shape.N, n_step);

MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {

// compute B,M,N index from iteration index
ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

// Get rhs tile, B
const size_t rhs_packed_offset =
UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(NIdx * n_step, Shape.K)
: kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(NIdx * n_step, Shape.K);

const std::byte* B_base = reinterpret_cast<const std::byte*>(DataParams[BIdx].PackedB);
auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);

// Get lhs tile, A
const size_t lhs_packed_offset =
UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(MIdx * m_step, Shape.K)
: kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(MIdx * m_step, Shape.K);

const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace;
auto ATile = reinterpret_cast<const std::byte*>(A_base + lhs_packed_offset);

auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? (Shape.M - MIdx * m_step) : m_step;
auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? (Shape.N - NIdx * n_step) : n_step;

float* dst_tile = reinterpret_cast<float*>(
reinterpret_cast<std::byte*>(DataParams[BIdx].C) +
MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) +
NIdx * n_step * sizeof(float)
);
}

if (UseSME2) {
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
dst_tile,
DataParams[BIdx].ldc * sizeof(float),
sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
}
else {
kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa(
TileSizeM, TileSizeN, Shape.K, ATile, BTile,
dst_tile,
DataParams[BIdx].ldc * sizeof(float),
sizeof(float),
-std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
);
}
});
}
129 changes: 100 additions & 29 deletions onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp
@@ -10,23 +10,27 @@
#include "test_util.h"
#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO

class MlasDynamicQgemmTest {
#include <cmath>
#include <limits>

class MlasDynamicQgemmTestBase {
private:
MatrixGuardBuffer<float> buffer_a;
MatrixGuardBuffer<float> buffer_bf;
MatrixGuardBuffer<int8_t> buffer_bq;
MatrixGuardBuffer<float> buffer_c;
MatrixGuardBuffer<float> buffer_c_ref;

public:
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
protected:
void Run(size_t M, size_t N, size_t K, size_t BatchSize,
MLAS_THREADPOOL* threadpool, bool require_threadpool, const char* run_tag) {
// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) {
if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME())
GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test.";
}
if (require_threadpool && threadpool == nullptr)
GTEST_SKIP() << "Dynamic QGEMM threading path requested but no MLAS thread pool is available.";

// Setup buffers for holding various data

float* A = buffer_a.GetBuffer(M * K * BatchSize);
// Buffer for holding floating point version of weight matrix
float* Bf = buffer_bf.GetBuffer(K * N * BatchSize);
@@ -44,6 +48,10 @@ class MlasDynamicQgemmTest {
// Quantize Bf → Bq and compute per-column scale and bias per batch
std::vector<std::vector<float>> b_scale_batches(BatchSize, std::vector<float>(N));
std::vector<std::vector<float>> b_bias_batches(BatchSize, std::vector<float>(N, 0.0f));
std::vector<std::vector<int8_t>> a_quant_batches(
BatchSize, std::vector<int8_t>(M * K));
std::vector<std::vector<float>> a_scale_batches(BatchSize, std::vector<float>(M));
std::vector<std::vector<int32_t>> a_zero_point_batches(BatchSize, std::vector<int32_t>(M));

for (size_t b = 0; b < BatchSize; ++b) {
for (size_t n = 0; n < N; ++n) {
@@ -66,6 +74,42 @@
}
}

// Quantize A rows to match the dynamic quantization performed by the kernel.
for (size_t b = 0; b < BatchSize; ++b) {
for (size_t m = 0; m < M; ++m) {
float min_val = std::numeric_limits<float>::max();
float max_val = std::numeric_limits<float>::lowest();
for (size_t k = 0; k < K; ++k) {
float v = A[b * M * K + m * K + k];
min_val = std::min(min_val, v);
max_val = std::max(max_val, v);
}
float rmin = std::min(0.0f, min_val);
float rmax = std::max(0.0f, max_val);
float inv_scale = (rmax == rmin) ? 1.0f : 255.0f / (rmax - rmin);
float scale = inv_scale ? 1.0f / inv_scale : 0.0f;
float descaled_min = rmin * inv_scale;
float descaled_max = rmax * inv_scale;
float zero_point_from_min_error = -128.0f + descaled_min;
float zero_point_from_max_error = 127.0f + descaled_max;
float zero_point = (zero_point_from_min_error + zero_point_from_max_error > 0.0f)
? (-128.0f - descaled_min)
: (127.0f - descaled_max);
zero_point = std::clamp(zero_point, -128.0f, 127.0f);
int32_t zp = static_cast<int32_t>(std::nearbyint(zero_point));

a_scale_batches[b][m] = scale;
a_zero_point_batches[b][m] = zp;

for (size_t k = 0; k < K; ++k) {
float v = A[b * M * K + m * K + k];
int32_t q = static_cast<int32_t>(std::round(v * inv_scale)) + zp;
q = std::clamp(q, -128, 127);
a_quant_batches[b][m * K + k] = static_cast<int8_t>(q);
}
}
}

// Prepare kernel parameters
MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K};
std::vector<uint8_t> packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K));
@@ -86,16 +130,17 @@
params[b].PackedB = packed_b;
}

// call MlasDynamicQGemmBatch Function
MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr);

// Compute reference result
for (size_t b = 0; b < BatchSize; ++b) {
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; ++n) {
float sum = 0.0f;
const float a_scale = a_scale_batches[b][m];
const int32_t a_zero_point = a_zero_point_batches[b][m];
for (size_t k = 0; k < K; ++k) {
float a = A[b * M * K + m * K + k];
int32_t a_q = static_cast<int32_t>(a_quant_batches[b][m * K + k]);
float a = static_cast<float>(a_q - a_zero_point) * a_scale;
float bval = static_cast<float>(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n];
sum += a * bval;
}
@@ -104,45 +149,67 @@
}
}

// Validate results
for (size_t i = 0; i < M * N * BatchSize; ++i) {
float abs_c_ref = std::abs(CRef[i]);
float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f;
float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f);
float abs_tol = 3.0f;
float allowed = std::max(rel_tol, abs_tol);
float diff = std::abs(C[i] - CRef[i]);
ASSERT_LE(diff, allowed);
}
}
std::fill(C, C + M * N * BatchSize, 0.0f);
MlasDynamicQGemmBatch(shape, params.data(), BatchSize, threadpool);

static const char* GetTestSuiteName() {
return "DynamicQgemm";
// Validate results
auto validate = [&](const char* tag) {
SCOPED_TRACE(tag);
for (size_t i = 0; i < M * N * BatchSize; ++i) {
float abs_c_ref = std::abs(CRef[i]);
float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f;
float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f);
float abs_tol = 3.0f;
float allowed = std::max(rel_tol, abs_tol);
float diff = std::abs(C[i] - CRef[i]);
ASSERT_LE(diff, allowed);
}
};

validate(run_tag);
}
};

class DynamicQgemmExecuteTest : public MlasTestFixture<MlasDynamicQgemmTest> {
class MlasDynamicQgemmSingleThreadTest : public MlasDynamicQgemmTestBase {
public:
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
Run(M, N, K, BatchSize, /*threadpool*/ nullptr, /*require_threadpool*/ false, "SingleThread");
}
static const char* GetTestSuiteName() { return "DynamicQgemmSingleThread"; }
};

class MlasDynamicQgemmThreadPoolTest : public MlasDynamicQgemmTestBase {
public:
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
MLAS_THREADPOOL* tp = GetMlasThreadPool();
if (!tp) GTEST_SKIP() << "Mlas thread pool not available";
Run(M, N, K, BatchSize, tp, /*require_threadpool*/ true, "ThreadPool");
}
static const char* GetTestSuiteName() { return "DynamicQgemmThreaded"; }
};

template <typename TMlasTester>
class DynamicQgemmExecuteTest : public MlasTestFixture<TMlasTester> {
public:
DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize)
: M_(M), N_(N), K_(K), BatchSize_(BatchSize) {}

void TestBody() override {
this->mlas_tester->Test(M_, N_, K_, BatchSize_);
MlasTestFixture<TMlasTester>::mlas_tester->Test(M_, N_, K_, BatchSize_);
}
static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) {
std::stringstream ss;
ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize;

std::string test_name = ss.str();

testing::RegisterTest(
MlasDynamicQgemmTest::GetTestSuiteName(),
TMlasTester::GetTestSuiteName(),
test_name.c_str(),
nullptr,
test_name.c_str(),
__FILE__,
__LINE__,
[=]() -> MlasTestFixture<MlasDynamicQgemmTest>* {
[=]() -> MlasTestFixture<TMlasTester>* {
return new DynamicQgemmExecuteTest(M, N, K, BatchSize);
});

@@ -166,7 +233,11 @@ class DynamicQgemmExecuteTest : public MlasTestFixture<MlasDynamicQgemmTest> {
size_t M_, N_, K_, BatchSize_;
};

static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
return DynamicQgemmExecuteTest::RegisterAll(is_short_execute);
static UNUSED_VARIABLE bool added_single = AddTestRegister([](bool is_short_execute) {
return DynamicQgemmExecuteTest<MlasDynamicQgemmSingleThreadTest>::RegisterAll(is_short_execute);
});

static UNUSED_VARIABLE bool added_threaded = AddTestRegister([](bool is_short_execute) {
return DynamicQgemmExecuteTest<MlasDynamicQgemmThreadPoolTest>::RegisterAll(is_short_execute);
});
#endif