
Commit 2d2a3e5

NEON kernels for NCHWc Convolution and Pooling (#25580)
### Description

This PR implements optimized Arm NEON kernels for NCHWc (channels-last with channel blocking) convolution and pooling operations in MLAS, significantly improving performance on Arm64 platforms.

### Motivation and Context

Fixes #24790

The new NCHWc kernels improve performance by 5-6x, depending on the thread configuration, the model, etc. For example, here is the gain observed during MobileNet inference; note the "Number of inferences per second" (93 inf/s -> 498 inf/s).

<details>
<summary>System configuration</summary>

```
Architecture:             aarch64
  CPU op-mode(s):         64-bit
  Byte Order:             Little Endian
CPU(s):                   64
  On-line CPU(s) list:    0-63
Vendor ID:                ARM
  Model name:             Neoverse-V2
    Model:                1
    Thread(s) per core:   1
    Core(s) per socket:   64
    Socket(s):            1
    Stepping:             r0p1
    BogoMIPS:             2000.00
    Flags:                fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 flagm2 frint svei8mm svebf16 i8mm bf16 dgh rng bti
Caches (sum of all):
  L1d:                    4 MiB (64 instances)
  L1i:                    4 MiB (64 instances)
  L2:                     128 MiB (64 instances)
  L3:                     36 MiB (1 instance)
NUMA:
  NUMA node(s):           1
  NUMA node0 CPU(s):      0-63
Vulnerabilities:
  Gather data sampling:   Not affected
  Itlb multihit:          Not affected
  L1tf:                   Not affected
  Mds:                    Not affected
  Meltdown:               Not affected
  Mmio stale data:        Not affected
  Reg file data sampling: Not affected
  Retbleed:               Not affected
  Spec rstack overflow:   Not affected
  Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:             Mitigation; __user pointer sanitization
  Spectre v2:             Not affected
  Srbds:                  Not affected
  Tsx async abort:        Not affected
```
</details>

<details>
<summary>Perf with current upstream kernels</summary>

```
./build/Linux/Release/onnxruntime_perf_test -x 32 -I -m times -r 1000 ~/scripts/mobilenet.onnx
Setting intra_op_num_threads to 32
Session creation time cost: 0.0238608 s
First inference time cost: 11 ms
Total inference time cost: 10.7458 s
Total inference requests: 1000
Average inference time cost: 10.7458 ms
Total inference run time: 10.7465 s
Number of inferences per second: 93.0534
Avg CPU usage: 50 %
Peak working set size: 70410240 bytes
Avg CPU usage:50
Peak working set size:70410240
Runs:1000
Min Latency: 0.0106707 s
Max Latency: 0.0113617 s
P50 Latency: 0.0107453 s
P90 Latency: 0.0107695 s
P95 Latency: 0.0107785 s
P99 Latency: 0.0107965 s
P999 Latency: 0.0113617 s
```
</details>

<details>
<summary>Perf with NCHWc kernels</summary>

```
./build/Linux/Release/onnxruntime_perf_test -x 32 -I -m times -r 1000 ~/scripts/mobilenet.onnx
Setting intra_op_num_threads to 32
Session creation time cost: 0.0358121 s
First inference time cost: 2 ms
Total inference time cost: 2.00561 s
Total inference requests: 1000
Average inference time cost: 2.00561 ms
Total inference run time: 2.00607 s
Number of inferences per second: 498.488
Avg CPU usage: 50 %
Peak working set size: 92467200 bytes
Avg CPU usage:50
Peak working set size:92467200
Runs:1000
Min Latency: 0.00198387 s
Max Latency: 0.00204784 s
P50 Latency: 0.00200537 s
P90 Latency: 0.0020155 s
P95 Latency: 0.00201822 s
P99 Latency: 0.0020251 s
P999 Latency: 0.00204784 s
```
</details>

Happy to run further performance tests as required.
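For readers new to the layout: NCHWc keeps a fixed-size block of channels innermost, so every spatial position stores a contiguous, vector-width run of channels. Below is a minimal indexing sketch, not code from this PR (the helper name is hypothetical), assuming the block size of 16 that the Arm64 path registers in platform.cpp:

```cpp
#include <cstddef>

// NCHWc layout: [N][C / kBlockSize][H][W][kBlockSize].
// kBlockSize matches MLAS_NEON_NCHWC_BLOCK_SIZE from this commit.
constexpr std::size_t kBlockSize = 16;

// Hypothetical helper: offset of element (n, c, h, w) in an NCHWc buffer,
// where C is padded up to a multiple of kBlockSize.
std::size_t NchwcOffset(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                        std::size_t C, std::size_t H, std::size_t W) {
    const std::size_t block = c / kBlockSize;  // which channel block
    const std::size_t lane = c % kBlockSize;   // position within the block
    return (((n * (C / kBlockSize) + block) * H + h) * W + w) * kBlockSize + lane;
}
```

Because the 16 channels of a block are adjacent in memory, a kernel can hold one output position in four NEON float32x4 registers and stream weights past them, which is where most of the speedup comes from.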
1 parent 323c87a commit 2d2a3e5

File tree

8 files changed (+874, -11 lines)

.github/workflows/android.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -71,8 +71,8 @@ jobs:
       run: |
         set -e -x
         BINARY_SIZE_THRESHOLD_ARGS=""
-        echo "Binary size threshold in bytes: 1436672"
-        BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672"
+        echo "Binary size threshold in bytes: 1722565"
+        BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1722565"
 
       # Ensure ANDROID_NDK_HOME is available and get its real path
       if [ -z "$ANDROID_NDK_HOME" ]; then
```

cmake/onnxruntime_mlas.cmake

Lines changed: 4 additions & 0 deletions
```diff
@@ -109,6 +109,8 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
     ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+    ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/spool_kernel_neon.cpp
   )
 
   set(mlas_platform_preprocess_srcs
@@ -431,6 +433,8 @@ else()
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.h
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
+    ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/spool_kernel_neon.cpp
   )
   if (onnxruntime_USE_KLEIDIAI)
     setup_kleidiai()
```

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 16 additions & 0 deletions
```diff
@@ -949,6 +949,15 @@ extern "C" {
 #if defined(__aarch64__) && defined(__linux__)
 MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero;
 MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd;
+#endif
+#if defined(MLAS_TARGET_ARM64)
+MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon;
+MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
+MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
+MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
+MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon;
+MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon;
+MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon;
 #endif
 MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero;
 MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd;
@@ -1335,6 +1344,12 @@ struct MLAS_PLATFORM {
     const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
     const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
     const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch;
+    MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel;
+    MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel;
+    MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel;
+    MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel;
+    MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount];
+    uint32_t NchwcBlockSize;
 #endif
     const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr};
 
@@ -1395,6 +1410,7 @@ struct MLAS_PLATFORM {
     int32_t MaximumThreadCount;
 #elif defined(MLAS_TARGET_ARM64)
     static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4;
+    static constexpr size_t MLAS_NEON_NCHWC_BLOCK_SIZE = 16;
 #else
     static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT;
 #endif
```
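These new MLAS_PLATFORM members follow MLAS's established pattern: kernels are plain function pointers in a per-process platform singleton, filled in once at startup and invoked indirectly on the hot path. A simplified sketch of that pattern, with invented names and signatures (the real MLAS_POOL_FLOAT_KERNEL prototype carries strides, padding, and a full work-block description):

```cpp
#include <cstddef>

// Simplified stand-ins for the MLAS types (assumptions, not the real API).
enum PoolingKind {
    MaximumPooling,
    AveragePoolingExcludePad,
    AveragePoolingIncludePad,
    PoolingKindCount
};

using PoolFloatKernelFn = void (*)(const float* input, float* output, std::size_t count);

struct PlatformSketch {
    // One slot per pooling kind, mirroring PoolFloatKernel[MlasPoolingKindCount].
    PoolFloatKernelFn PoolFloatKernel[PoolingKindCount];
};

// Hot-path dispatch: a single indirect call through the table that platform
// initialization populated with the NEON implementations.
void Pool(const PlatformSketch& platform, PoolingKind kind,
          const float* input, float* output, std::size_t count) {
    platform.PoolFloatKernel[kind](input, output, count);
}
```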

onnxruntime/core/mlas/lib/platform.cpp

Lines changed: 9 additions & 0 deletions
```diff
@@ -558,6 +558,15 @@ Return Value:
     this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon;
     this->EltwiseDispatch = &MlasEltwiseDispatchNeon;
 
+    this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
+    this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
+    this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
+    this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
+    this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon;
+    this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon;
+    this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon;
+    this->NchwcBlockSize = MLAS_NEON_NCHWC_BLOCK_SIZE;
+
     //
     // Check if the processor supports ASIMD dot product instructions.
     //
```
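Since the NEON branch registers these kernels unconditionally, downstream code can detect the NCHWc path through the block size. A hedged sketch, assuming MlasNchwcGetBlockSize() is the existing public MLAS accessor for this value (the helper below is hypothetical):

```cpp
#include <cstddef>

// Assumed declaration from the public MLAS header (mlas.h): returns the
// platform's NCHWc channel block size, or 1 when no NCHWc kernels exist.
std::size_t MlasNchwcGetBlockSize();

// Hypothetical helper: a layout transformer would apply this kind of check
// before rewriting a graph into blocked NCHWc form.
bool NchwcPathAvailable() {
    return MlasNchwcGetBlockSize() > 1;  // 16 on this NEON path
}
```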

onnxruntime/core/mlas/lib/sconv.h

Lines changed: 25 additions & 0 deletions
```diff
@@ -0,0 +1,25 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    sconv.h
+
+Abstract:
+
+    This module defines convolution kernel flags for configuring convolution
+    operations including output accumulation, bias addition, and activations.
+
+--*/
+
+//
+// Define the convolution kernel flags.
+//
+
+#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001
+#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002
+#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004
+#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008
```