2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -47,7 +47,7 @@ protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/downlo
psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013
pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9
pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/877328f188a3c7d1fa855871a278eb48d530c4c0.zip;9152d4bf6b8bde9f19b116de3bd8a745097ed9df
pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/f858c30bcb16f8effd5ff46996f0514539e17abc.zip;66a964eda7de60c925e2e26f71f9bbe31698997b
re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
4 changes: 2 additions & 2 deletions cmake/vcpkg-ports/cpuinfo/portfile.cmake
@@ -6,8 +6,8 @@ endif()
vcpkg_from_github(
OUT_SOURCE_PATH SOURCE_PATH
REPO pytorch/cpuinfo
REF 877328f188a3c7d1fa855871a278eb48d530c4c0
SHA512 b6d5a9ce9996eee3b2f09f39115f7ae178fe4d4814cc35b049a59d04a82228e268aa52d073c307ccb56a427428622940e1c77f004c99851dfca0d3a5d803658b
REF f858c30bcb16f8effd5ff46996f0514539e17abc
SHA512 cd7c0c1ea59fac69f2746f65f59656798eeb87410c304ac9d3b3d26ebea4f4124d1426c10fb4b87ff5f93f367ea10d63337f519ee3c3f8fefbb4b7ebf6438130
HEAD_REF master
PATCHES
patch_cpuinfo_h_for_arm64ec.patch
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME2()
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/mlas/inc/mlas.h"
@@ -213,9 +212,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
}
}

// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
// We check that here too before attempting to use them.
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME2()) {
if (!MlasIsDynamicQGemmAvailable()) {
can_use_dynamic_quant_mlas_ = false;
}

33 changes: 8 additions & 25 deletions onnxruntime/core/common/cpuid_info.cc
@@ -237,9 +237,9 @@
#elif defined(_WIN32) // ^ defined(__linux__)

void CPUIDInfo::ArmWindowsInit() {
// Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry
// Read MIDR register values from Windows registry
// There should be one per CPU
std::vector<uint64_t> midr_values{}, id_aa64isar1_el1_values{};
std::vector<uint64_t> midr_values{};

[GitHub Actions / Optional Lint C++] cpplint (via reviewdog) warning at onnxruntime/core/common/cpuid_info.cc:242: Add #include <vector> for vector<> [build/include_what_you_use] [4]

// TODO!! Don't support multiple processor group yet!!
constexpr int MAX_CORES = 64;
@@ -272,17 +272,7 @@
break;
}

uint64_t id_aa64isar1_el1_value;
data_size = sizeof(id_aa64isar1_el1_value);

// CP 4031 corresponds to ID_AA64ISAR1_EL1 register
if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4031", RRF_RT_REG_QWORD,
nullptr, &id_aa64isar1_el1_value, &data_size) != ERROR_SUCCESS) {
break;
}

midr_values.push_back(midr_value);
id_aa64isar1_el1_values.push_back(id_aa64isar1_el1_value);
}

// process midr_values
@@ -308,22 +298,15 @@
}
}

has_arm_neon_i8mm_ = std::all_of(
id_aa64isar1_el1_values.begin(), id_aa64isar1_el1_values.end(),
[](uint64_t id_aa64isar1_el1_value) {
// I8MM, bits [55:52]
return ((id_aa64isar1_el1_value >> 52) & 0xF) != 0;
});

has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);

#if defined(CPUINFO_SUPPORTED)
if (pytorch_cpuinfo_init_) {
has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
// cpuinfo_has_arm_i8mm() doesn't work on Windows yet. See https://github.com/pytorch/cpuinfo/issues/279.
// has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && has_arm_neon_i8mm_;
has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm();
has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm();
has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16();
has_arm_sme_ = cpuinfo_has_arm_sme();
has_arm_sme2_ = cpuinfo_has_arm_sme2();
}
#endif // defined(CPUINFO_SUPPORTED)
}
@@ -397,4 +380,4 @@
#endif
#endif // defined(CPUIDINFO_ARCH_ARM)
}
} // namespace onnxruntime
} // namespace onnxruntime
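
For orientation, the flags this init path populates are consumed through CPUIDInfo::GetCPUIDInfo(), the same accessor pattern used elsewhere in this PR. A minimal hedged sketch follows; the wrapper function is illustrative only, and just HasArm_SME2() is taken from the diff.

#include "core/common/cpuid_info.h"

namespace onnxruntime {

// Illustrative helper (not part of this PR): kernels consult the cached feature
// flags that ArmWindowsInit() or the other platform init paths filled in at startup.
inline bool PlatformSupportsSme2() {
  return CPUIDInfo::GetCPUIDInfo().HasArm_SME2();
}

}  // namespace onnxruntime
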
2 changes: 1 addition & 1 deletion onnxruntime/core/common/cpuid_info.h
@@ -171,4 +171,4 @@ class CPUIDInfo {
uint32_t vendor_id_;
};

} // namespace onnxruntime
} // namespace onnxruntime
10 changes: 10 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -634,6 +634,7 @@ MlasGemm(
{
MlasGemmBatch(Shape, &DataParams, 1, ThreadPool);
}

/**
* @brief Parameters that define the shape of a dynamically quantized GEMM operation.
*
@@ -646,6 +647,7 @@ struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS {
size_t N = 0; /**< Column size of matrix B */
size_t K = 0; /**< Column size of matrix A and Row size of matrix B */
};

/**
* @brief Parameters that define the data buffers and layout for a dynamic quant GEMM.
*
@@ -680,6 +682,14 @@ MlasDynamicQGemm (
MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool);
}

/**
* @brief Determines whether a dynamic quantized GEMM implementation is available on the current platform.
*
* MlasDynamicQGemm() and MlasDynamicQGemmBatch() should only be called if this function returns true.
*/
bool
MLASCALL
MlasIsDynamicQGemmAvailable();

//
// Symmetric QGEMM has limited buffer overrun.
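
To make the intended call pattern concrete, here is a minimal caller-side sketch of the new availability check combined with the shape parameters declared above. Only MlasIsDynamicQGemmAvailable() and MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS (with the N and K fields shown in the diff) come from this PR; the M field, the helper name, and the fallback behaviour are assumptions for illustration.

#include <cstddef>

#include "core/mlas/inc/mlas.h"

// Hypothetical helper: fills the shape parameters only when the dynamic
// quantized GEMM kernels can actually run on this platform.
bool TryDescribeDynamicQGemm(size_t M, size_t N, size_t K,
                             MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& shape) {
  if (!MlasIsDynamicQGemmAvailable()) {
    // MlasDynamicQGemm()/MlasDynamicQGemmBatch() would be no-ops here (no
    // SME2/KleidiAI kernels), so the caller should pick a fallback GEMM path.
    return false;
  }
  shape.M = M;
  shape.N = N;
  shape.K = K;
  return true;
}

This mirrors the DynamicQuantizeMatMul and unit-test changes in this PR, which gate their dynamic-quant paths on the same check instead of querying SME2 support directly.
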
19 changes: 15 additions & 4 deletions onnxruntime/core/mlas/lib/qgemm.cpp
@@ -201,6 +201,17 @@ MlasGemmBatch(
});
}

bool
MLASCALL
MlasIsDynamicQGemmAvailable()
{
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
return ArmKleidiAI::UseSME2;
#else
return false;
Review comment (Member): Could we add a TODO here with relevant commentary and a reminder to adjust these checks once KAI support for Windows is added?

#endif
}

void
MLASCALL
MlasDynamicQGemmBatch (
@@ -211,7 +222,7 @@ MlasDynamicQGemmBatch (
) {
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback and putting in guards. This implementation is SME2 specific.
if(ArmKleidiAI::UseSME2){
if (ArmKleidiAI::UseSME2) {
ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
}
#endif
@@ -336,7 +347,7 @@ MlasDynamicQgemmPackBSize(
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback available
//TODO: Insert Override
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
if (ArmKleidiAI::UseSME2) { //Still require this since no override
bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K);
}
#endif
@@ -407,7 +418,7 @@ Return Value:
~(BufferAlignment - 1);
// If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
// this use case. Concat both packed representations for later decision. This allows for cases later
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
// for better performance
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
}
@@ -425,7 +436,7 @@ MlasDynamicQgemmPackB(
{
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
if (ArmKleidiAI::UseSME2) { //Still require this since no override
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
}
#endif
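
Following up on the review comment above, one possible shape for the requested TODO inside MlasIsDynamicQGemmAvailable() is sketched here; the wording is hypothetical and not part of this PR.

#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
    return ArmKleidiAI::UseSME2;
#else
    // TODO: Re-evaluate this check (and the MSVC exclusion used throughout this
    // file) once KleidiAI gains Windows support, so the dynamic QGEMM path can
    // be enabled there as well.
    return false;
#endif
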
7 changes: 3 additions & 4 deletions onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp
@@ -7,8 +7,8 @@
// Currently this test only applies to KleidiAI Guard against it running in any other situation
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)

#include "mlas.h"
#include "test_util.h"
#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO

class MlasDynamicQgemmTest {
private:
@@ -20,9 +20,8 @@ class MlasDynamicQgemmTest {

public:
void Test(size_t M, size_t N, size_t K, size_t BatchSize) {
// Currently, MlasDynamicQGemmBatch() and associated functions require SME2 or else they are no-ops.
if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME2 but it was not detected. Skipping test.";
if (!MlasIsDynamicQGemmAvailable()) {
GTEST_SKIP() << "MlasDynamicQGemmBatch() is not supported on this platform. Skipping test.";
}

// Setup buffers for holding various data