Skip to content

Commit bd8f781

Browse files
milpuz01, xhcao, quic-calvnguy, quic_calvnguy, Jiawei-Shao
authored
mlas/arm64: add NEON conv asm kernels and tune NCHWC kernel selection (#27099)
## Overview This PR adds ARM64 NEON assembly micro‑kernels for NCHW, depthwise, and pointwise convolution, wires them into the MLAS build, and adds shape‑based selection heuristics for NCHWC depthwise/pointwise to favor the asm kernels in safe cases (stride‑1 pointwise; wider depthwise outputs). The BF16 path is unchanged. ## Key changes - cmake/onnxruntime_mlas.cmake - Add new AArch64 assembly sources for NCHW, depthwise, and pointwise conv to the MLAS build. - onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S - New vectorised NCHW convolution micro‑kernel. - onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S - New vectorised depthwise micro‑kernel (fast path for in‑bounds loads, slow path for padding). - onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S - New vectorised pointwise micro‑kernel (multi‑output reuse). - onnxruntime/core/mlas/lib/mlasi.h, onnxruntime/core/mlas/lib/platform.cpp - Declare/register new asm kernels and prefer them on ARM64. - onnxruntime/core/mlas/lib/snchwc.cpp - Heuristics: use pointwise asm when StrideHeight == 1 && StrideWidth == 1 and OutputThisIteration >= 4; use depthwise asm when OutputWidth >= 4. - onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp - Include fix for the conv kernel flags header. 
## Performance Numbers below are expressed as multipliers vs the non‑NCHWC baseline (same model and perf_test settings): Baseline (no `--enable_arm_neon_nchwc`) - 8 cores: 1.00× - 16 cores: 1.00× With `--enable_arm_neon_nchwc` (no asm additions/heuristics) - 8 cores: 1.18× - 16 cores: 1.24× With this PR (asm kernels + heuristics) - 8 cores: 1.77× - 16 cores: 2.54× ## Testing - `./build.sh --config Release --build_shared_lib --parallel --compile_no_warning_as_error --skip_submodule_sync --skip_tests --enable_pybind --build_wheel --enable_arm_neon_nchwc` - `OMP_NUM_THREADS=8 ./build/Linux/Release/onnxruntime_perf_test -I -m times -r 1000 --x 8 ~/mobilenetv2-7.onnx` --------- Signed-off-by: Milos Puzovic <milos.puzovic@arm.com> Co-authored-by: xhcao <xinghua.cao@intel.com> Co-authored-by: quic-calvnguy <quic_calvnguy@quicinc.com> Co-authored-by: quic_calvnguy <quic_calvnguy@quic_inc.com> Co-authored-by: Jiawei Shao <jiawei.shao@intel.com>
1 parent 0c2502c commit bd8f781

File tree

8 files changed

+1648
-7
lines changed

8 files changed

+1648
-7
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,16 @@ function (setup_arm_neon_nchwc)
328328
${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
329329
${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp
330330
)
331+
if(NOT WIN32)
  # Hand-written AArch64 micro-kernels for the NCHW, depthwise and pointwise
  # convolution paths. Keeping them in separate GNU-syntax assembly files
  # allows tight control over register allocation and avoids the overhead of
  # C++/intrinsics based code generation. They are excluded on Windows, where
  # the toolchain cannot assemble GNU-syntax .S sources.
  target_sources(onnxruntime_mlas PRIVATE
    ${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S
    ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
    ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
  )
endif()
331341
list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
332342
set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
333343
endfunction ()
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/*++
2+
SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
3+
SPDX-License-Identifier: MIT
4+
5+
Module Name:
6+
7+
SconvDepthwiseKernelNeon.S
8+
9+
Abstract:
10+
11+
Optimized AArch64 assembly implementation of the depthwise convolution
12+
micro-kernel used by the NCHWc single precision path.
13+
14+
This kernel performs the following optimisations:
15+
* Produce a fast path for interior output positions where all input
16+
accesses are guaranteed to be in-bounds and can be loaded with a pair
17+
of 128-bit loads.
18+
* When an output position touches padding, only the affected 4-wide
19+
lanes are checked individually and loaded; others are zeroed. This
20+
mirrors the behavior of the C++ helper LoadInputVectorWithBounds.
21+
* Keep the multiply/accumulate operations tightly scheduled to hide the
22+
load latency.
23+
24+
The kernel computes a single output position for a 16 channel block and is
25+
repeatedly invoked by the high level dispatch code.
26+
27+
--*/
28+
29+
#include "asmmacro.h"
30+
31+
.text
32+
33+
// Offsets for stack based parameters. AArch64 passes the first eight
34+
// arguments in registers (x0-x7). The remaining parameters are read from the
35+
// stack directly. The layout is defined by the C compiler so use constant
36+
// offsets here.
37+
38+
.equ .Ldw_InputBase, 0
39+
.equ .Ldw_InputWidth, 8
40+
.equ .Ldw_DilatedInputWidth, 16
41+
.equ .Ldw_OutputCountLeftPad, 24
42+
.equ .Ldw_OutputCount, 32
43+
.equ .Ldw_OutputCountRightPad, 40
44+
.equ .Ldw_Bias, 48
45+
.equ .Ldw_Flags, 56
46+
47+
// Prototype
48+
//
49+
// void
50+
// MlasConvDepthwiseFloatKernelNeonAsm(
51+
// const float* Input, // x0
52+
// const float* Filter, // x1
53+
// float* Output, // x2
54+
// size_t StrideWidth, // x3 (bytes)
55+
// size_t DilationWidth, // x4 (bytes)
56+
// size_t InputStride, // x5 (unused)
57+
// size_t KernelHeight, // x6
58+
// size_t KernelWidth, // x7
59+
// const float* InputBase, // [sp + 0]
60+
// size_t InputWidth, // [sp + 8] (bytes)
61+
// size_t DilatedInputWidth, // [sp + 16] (bytes)
62+
// size_t OutputCountLeftPad, // [sp + 24]
63+
// size_t OutputCount, // [sp + 32]
64+
// size_t OutputCountRightPad, // [sp + 40]
65+
// const float* Bias, // [sp + 48]
66+
// unsigned KernelFlags); // [sp + 56]
67+
//
68+
69+
FUNCTION_ENTRY MlasConvDepthwiseFloatKernelNeonAsm
70+
71+
// Load the stack parameters used in the hot loops.
72+
ldr x8, [sp,#.Ldw_InputBase] // base of valid input row
73+
ldr x9, [sp,#.Ldw_InputWidth] // row width in bytes
74+
ldr x10,[sp,#.Ldw_DilatedInputWidth] // stride between rows
75+
ldr x11,[sp,#.Ldw_OutputCountLeftPad]
76+
ldr x12,[sp,#.Ldw_OutputCount]
77+
ldr x13,[sp,#.Ldw_OutputCountRightPad]
78+
ldr x14,[sp,#.Ldw_Bias]
79+
ldr w15,[sp,#.Ldw_Flags]
80+
81+
// Preserve callee-saved registers used by this routine.
82+
stp x29,x30,[sp,#-16]!
83+
stp x27,x28,[sp,#-16]!
84+
stp x25,x26,[sp,#-16]!
85+
stp x23,x24,[sp,#-16]!
86+
stp x21,x22,[sp,#-16]!
87+
stp x19,x20,[sp,#-16]!
88+
stp d12,d13,[sp,#-16]!
89+
stp d14,d15,[sp,#-16]!
90+
91+
// Compute total number of output elements to produce.
92+
add x16,x11,x12
93+
add x16,x16,x13
94+
95+
// Load bias vectors when required; otherwise all zeros are used to
96+
// initialize the accumulators.
97+
eor v20.16b, v20.16b, v20.16b
98+
eor v21.16b, v21.16b, v21.16b
99+
eor v22.16b, v22.16b, v22.16b
100+
eor v23.16b, v23.16b, v23.16b
101+
tbz w15,#1,1f // no bias addition
102+
ldp q20,q21,[x14],#32
103+
ldp q22,q23,[x14]
104+
1:
105+
// Constant zero used by ReLU handling.
106+
eor v24.16b, v24.16b, v24.16b
107+
108+
mov x17,#0 // output index
109+
110+
// ---------------------------------------------------------------------------
111+
// Loop over output elements. Each iteration computes one output position for
112+
// 16 channels.
113+
// ---------------------------------------------------------------------------
114+
.Ldw_OutputLoop:
115+
// Start accumulators from bias or zeros.
116+
mov v0.16b, v20.16b
117+
mov v1.16b, v21.16b
118+
mov v2.16b, v22.16b
119+
mov v3.16b, v23.16b
120+
121+
mov x20,x1 // reset filter pointer
122+
mov x21,#0 // kh = 0
123+
124+
// Base pointer for this output index across the input.
125+
madd x19,x17,x3,x0 // Input + out_idx*StrideWidth
126+
127+
.Ldw_HeightLoop:
128+
// Compute [row_start,row_end) for the current kernel row.
129+
madd x22,x21,x10,x8 // row_start
130+
add x27,x22,x9 // row_end
131+
sub x29,x27,#64 // row_end - 64 (fast path)
132+
add x28,x27,#-16 // row_end - 16
133+
134+
// Base address for the first kw element on this row.
135+
madd x26,x21,x10,x19 // input for this row
136+
137+
mov x25,x7 // kw remaining
138+
139+
.Ldw_WidthLoop:
140+
// Fast path: the 16-lane load fits completely within the row.
141+
cmp x26,x22
142+
b.lo .Ldw_SlowPath
143+
cmp x26,x29
144+
bhi .Ldw_SlowPath
145+
// Load 16 input values for the current position.
146+
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x26]
147+
b .Ldw_DoFma
148+
149+
.Ldw_SlowPath:
150+
// Zero registers and conditionally load each 4-wide vector when it is
151+
// entirely within bounds. This matches the behavior of the C++
152+
// helper LoadInputVectorWithBounds.
153+
eor v16.16b, v16.16b, v16.16b
154+
eor v17.16b, v17.16b, v17.16b
155+
eor v18.16b, v18.16b, v18.16b
156+
eor v19.16b, v19.16b, v19.16b
157+
158+
mov x23,x26
159+
cmp x23,x22
160+
b.lt 2f
161+
cmp x23,x28
162+
b.hi 2f
163+
ldr q16,[x23]
164+
2:
165+
add x23,x26,#16
166+
cmp x23,x22
167+
b.lt 3f
168+
cmp x23,x28
169+
b.hi 3f
170+
ldr q17,[x23]
171+
3:
172+
add x23,x26,#32
173+
cmp x23,x22
174+
b.lt 4f
175+
cmp x23,x28
176+
b.hi 4f
177+
ldr q18,[x23]
178+
4:
179+
add x23,x26,#48
180+
cmp x23,x22
181+
b.lt 5f
182+
cmp x23,x28
183+
b.hi 5f
184+
ldr q19,[x23]
185+
5:
186+
187+
.Ldw_DoFma:
188+
// Load filter block and update accumulators.
189+
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x20], #64
190+
fmla v0.4s, v16.4s, v12.4s
191+
fmla v1.4s, v17.4s, v13.4s
192+
fmla v2.4s, v18.4s, v14.4s
193+
fmla v3.4s, v19.4s, v15.4s
194+
195+
add x26,x26,x4 // advance to next kw
196+
subs x25,x25,#1
197+
b.ne .Ldw_WidthLoop
198+
199+
add x21,x21,#1
200+
cmp x21,x6
201+
blt .Ldw_HeightLoop
202+
203+
// Compute destination pointer for this output element.
204+
add x23,x2,x17,lsl #6 // 16 floats per output
205+
206+
// Accumulate existing output when requested.
207+
tbz w15,#0,6f
208+
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x23]
209+
fadd v0.4s, v0.4s, v16.4s
210+
fadd v1.4s, v1.4s, v17.4s
211+
fadd v2.4s, v2.4s, v18.4s
212+
fadd v3.4s, v3.4s, v19.4s
213+
6:
214+
// Optional ReLU activation.
215+
tbz w15,#2,8f
216+
fmax v0.4s, v0.4s, v24.4s
217+
fmax v1.4s, v1.4s, v24.4s
218+
fmax v2.4s, v2.4s, v24.4s
219+
fmax v3.4s, v3.4s, v24.4s
220+
8:
221+
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x23]
222+
223+
add x17,x17,#1
224+
cmp x17,x16
225+
blt .Ldw_OutputLoop
226+
227+
ldp d14,d15,[sp],#16
228+
ldp d12,d13,[sp],#16
229+
ldp x19,x20,[sp],#16
230+
ldp x21,x22,[sp],#16
231+
ldp x23,x24,[sp],#16
232+
ldp x25,x26,[sp],#16
233+
ldp x27,x28,[sp],#16
234+
ldp x29,x30,[sp],#16
235+
236+
ret
237+
238+
.end

0 commit comments

Comments
 (0)