Skip to content

Commit 64278b2

Browse files
committed
Add more opt + remove Gelu fusion for now
1 parent 7cf2da6 commit 64278b2

File tree

9 files changed

+446
-181
lines changed

9 files changed

+446
-181
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
2323
${MLAS_SRC_DIR}/qgemm.cpp
2424
${MLAS_SRC_DIR}/qdwconv.cpp
2525
${MLAS_SRC_DIR}/convolve.cpp
26+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_greater_than_1.cpp
2627
${MLAS_SRC_DIR}/convsym.cpp
2728
${MLAS_SRC_DIR}/pooling.cpp
2829
${MLAS_SRC_DIR}/transpose.cpp
@@ -115,7 +116,7 @@ function(setup_mlas_source_for_windows)
115116
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
116117
${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
117118
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
118-
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
119+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp
119120
)
120121

121122
set(mlas_platform_preprocess_srcs
@@ -488,7 +489,7 @@ else()
488489
${MLAS_SRC_DIR}/eltwise_kernel_neon.h
489490
${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
490491
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
491-
${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp
492+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp
492493
)
493494

494495
# Conditionally add the SVE implementation if compiler supports it

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ enum MLAS_CONV_ALGORITHM {
877877
MlasConvAlgorithmGemmDirect,
878878
MlasConvAlgorithmExpandThenGemm,
879879
MlasConvAlgorithmExpandThenGemmSegmented,
880+
MlasConvAlgorithmDepthwiseWithMultiplier,
880881
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
881882
MlasConvAlgorithmDepthwise,
882883
#endif

onnxruntime/core/mlas/lib/convolve.cpp

Lines changed: 126 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,77 @@ Return Value:
892892

893893
#endif
894894

895+
void
896+
MlasDepthwiseWithMultiplierThreaded(
897+
void* Context,
898+
ptrdiff_t Index
899+
)
900+
/*++
901+
902+
Routine Description:
903+
904+
This routine is invoked from a worker thread to execute a segment of a
905+
convolution operation.
906+
907+
If using this, the entire convolution operation is parallelized on the
908+
(batch size * group count) parameter and this routine has logic to
909+
perform a specific thread's shard of the entire Convolution operation.
910+
911+
Arguments:
912+
913+
Context - Supplies the pointer to the context for the threaded operation.
914+
915+
Index - Supplies the current index of the threaded operation.
916+
917+
Return Value:
918+
919+
None.
920+
921+
--*/
922+
{
923+
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
924+
925+
const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
926+
const size_t GroupCount = Parameters->GroupCount;
927+
const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
928+
929+
size_t BatchGroupStart;
930+
size_t BatchGroupRemaining;
931+
932+
MlasPartitionWork(Index, WorkBlock->TargetThreadCount, BatchGroupCount,
933+
&BatchGroupStart, &BatchGroupRemaining);
934+
935+
size_t BatchGroupEnd = BatchGroupStart + BatchGroupRemaining;
936+
937+
const size_t FilterCount = Parameters->FilterCount;
938+
const size_t OutputSize = Parameters->OutputSize;
939+
const size_t K = Parameters->K;
940+
941+
const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
942+
const size_t OutputGroupSize = FilterCount * OutputSize;
943+
const size_t FilterGroupSize = FilterCount * K;
944+
945+
const float* input = WorkBlock->Input + BatchGroupStart * InputGroupSize;
946+
float* output = WorkBlock->Output + BatchGroupStart * OutputGroupSize;
947+
948+
for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
949+
size_t group = bg % GroupCount;
950+
951+
const float* filter = WorkBlock->Filter + group * FilterGroupSize;
952+
const float* bias = WorkBlock->Bias;
953+
if (bias != nullptr) {
954+
bias += group * FilterCount;
955+
}
956+
957+
MlasConvDepthwiseWithMultiplierFloat_CHW(Parameters, input, filter, output);
958+
MlasActivation(Parameters->Activation, output, bias, FilterCount,
959+
OutputSize, OutputSize);
960+
961+
input += InputGroupSize;
962+
output += OutputGroupSize;
963+
}
964+
}
965+
895966
inline
896967
bool
897968
MlasConvTryMultithread(
@@ -1106,7 +1177,6 @@ Return Value:
11061177
return;
11071178
}
11081179

1109-
11101180
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
11111181

11121182
if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1) || (GroupCount > 1))) {
@@ -1135,6 +1205,28 @@ Return Value:
11351205

11361206
#endif
11371207

1208+
if (Algorithm == MlasConvAlgorithmDepthwiseWithMultiplier && ((BatchCount > 1) || (GroupCount > 1))) {
1209+
const size_t BatchGroupCount = BatchCount * GroupCount;
1210+
ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
1211+
1212+
if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
1213+
TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
1214+
}
1215+
1216+
MLAS_CONV_WORK_BLOCK WorkBlock;
1217+
WorkBlock.Parameters = Parameters;
1218+
WorkBlock.Input = Input;
1219+
WorkBlock.Filter = Filter;
1220+
WorkBlock.Bias = Bias;
1221+
WorkBlock.WorkingBuffer = nullptr;
1222+
WorkBlock.Output = Output;
1223+
WorkBlock.TargetThreadCount = TargetThreadCount;
1224+
1225+
MlasExecuteThreaded(MlasDepthwiseWithMultiplierThreaded, &WorkBlock,
1226+
TargetThreadCount, ThreadPool);
1227+
return;
1228+
}
1229+
11381230
//
11391231
// Iterate over each batch and group.
11401232
//
@@ -1209,6 +1301,13 @@ Return Value:
12091301

12101302
#endif
12111303

1304+
case MlasConvAlgorithmDepthwiseWithMultiplier:
1305+
{
1306+
MlasConvDepthwiseWithMultiplierFloat_CHW(Parameters, Input, filter, Output);
1307+
MlasActivation(Parameters->Activation, Output, bias, FilterCount, OutputSize, OutputSize);
1308+
break;
1309+
}
1310+
12121311
case MlasConvAlgorithmExpandThenGemmSegmented:
12131312
{
12141313
//
@@ -1453,6 +1552,26 @@ Return Value:
14531552

14541553
} else {
14551554

1555+
// Commonly found in MobileNet like models, where the depthwise convolution with
1556+
// depth_multiplier = 2 is used together with 7x7 kernel shape, stride = 2 and dilation = 1.
1557+
// This is a very specific scenario, but it is worth to have a specialized kernel for it given
1558+
// the popularity of MobileNet models.
1559+
if (Dimensions == 2
1560+
// depthwise convolution
1561+
&& Parameters->GroupCount > 1
1562+
&& Parameters->InputChannels == 1
1563+
// depth_multiplier = 2
1564+
&& Parameters->FilterCount == 2
1565+
// current scope for specialized kernel is for the 7x7 kernel shape
1566+
&& Parameters->KernelShape[0] == 7 && Parameters->KernelShape[1] == 7
1567+
// keep this specialized kernel only for stride = 2x2
1568+
&& Parameters->StrideShape[0] == 2 && Parameters->StrideShape[1] == 2
1569+
// keep this specialized kernel only for dilation = 1x1
1570+
&& Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) {
1571+
Parameters->Algorithm = MlasConvAlgorithmDepthwiseWithMultiplier;
1572+
return;
1573+
}
1574+
14561575
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
14571576

14581577
// Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
@@ -1468,12 +1587,12 @@ Return Value:
14681587
#endif
14691588

14701589
if (Dimensions == 2
1471-
&& Parameters->FilterCount == 1 && Parameters->InputChannels == 1
1472-
&& Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3
1473-
&& Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1
1474-
&& Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1
1475-
&& depthwise_conv_stride_support_check
1476-
&& Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) {
1590+
&& Parameters->FilterCount == 1 && Parameters->InputChannels == 1
1591+
&& Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3
1592+
&& Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1
1593+
&& Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1
1594+
&& depthwise_conv_stride_support_check
1595+
&& Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) {
14771596

14781597
*WorkingBufferSize = Parameters->InputShape[1] + 2;
14791598
Parameters->Algorithm = MlasConvAlgorithmDepthwise;

onnxruntime/core/mlas/lib/mlasi.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,8 +1638,6 @@ MlasFp32FromBits(
16381638
#endif
16391639

16401640
#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
1641-
1642-
16431641
void
16441642
MLASCALL
16451643
MlasConvDepthwiseFloat_CHW(
@@ -1652,6 +1650,13 @@ MlasConvDepthwiseFloat_CHW(
16521650

16531651
#endif
16541652

1653+
void
1654+
MlasConvDepthwiseWithMultiplierFloat_CHW(
1655+
const MLAS_CONV_PARAMETERS* Parameters,
1656+
const float* Input,
1657+
const float* Filter,
1658+
float* Output
1659+
);
16551660

16561661
//
16571662
// Define the missing ARM64 NEON intrinsic macros from arm64_neon.h that enable

onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp renamed to onnxruntime/core/mlas/lib/sconv_nchw_depthwise_multiplier_1.cpp

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@ Licensed under the MIT License.
66
77
Module Name:
88
9-
sconv_nchw_kernel_neon.cpp
9+
sconv_nchw_depthwise_multiplier_1.cpp
1010
1111
Abstract:
1212
13-
This module implements the single precision NCHW convolution kernels for ARM NEON.
13+
This module implements the single precision NCHW depthwise convolution kernels
14+
for depth multiplier 1.
1415
1516
--*/
1617

1718

1819
#include "mlasi.h"
19-
#include <arm_neon.h>
20+
#include <cassert>
2021

2122
MLAS_FORCEINLINE float DepthwiseSampleValue(
2223
const float* row,
@@ -50,7 +51,7 @@ MLAS_FORCEINLINE float DepthwiseAccumulateRowScalar(
5051
}
5152

5253
MLAS_FORCEINLINE void DepthwiseAccumulateRowVector(
53-
float32x4_t& acc,
54+
MLAS_FLOAT32X4& acc,
5455
const float* row,
5556
size_t base,
5657
float w0,
@@ -63,9 +64,9 @@ MLAS_FORCEINLINE void DepthwiseAccumulateRowVector(
6364
}
6465

6566
const float* r = row + base;
66-
const float32x4_t c0 = MlasLoadFloat32x4(r);
67-
const float32x4_t c1 = MlasLoadFloat32x4(r + 1);
68-
const float32x4_t c2 = MlasLoadFloat32x4(r + 2);
67+
const MLAS_FLOAT32X4 c0 = MlasLoadFloat32x4(r);
68+
const MLAS_FLOAT32X4 c1 = MlasLoadFloat32x4(r + 1);
69+
const MLAS_FLOAT32X4 c2 = MlasLoadFloat32x4(r + 2);
6970

7071
acc = MlasMultiplyAddFloat32x4(c0, w0, acc);
7172
acc = MlasMultiplyAddFloat32x4(c1, w1, acc);
@@ -107,12 +108,31 @@ MLAS_FORCEINLINE float DepthwiseComputeEdge(
107108
return acc;
108109
}
109110

110-
static void DepthwiseConv3x3Stride1PadLe1Neon(
111+
static
112+
void
113+
MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(
111114
const MLAS_CONV_PARAMETERS* Parameters,
112115
const float* Input,
113116
const float* Filter,
114117
float* Output
115-
)
118+
)
119+
/*++
120+
121+
Routine Description:
122+
123+
This routine is an inner kernel to compute convolution on one channel input with one filter channel.
124+
125+
Arguments:
126+
127+
Parameters - conv parameters calculated based on conv parameters like padding, strides, dilations, etc.
128+
129+
Input - input channel data start. Input is NCHW, so this pointer points to single H x W image data.
130+
131+
Filter - Whole filters are of F x CpG x FH x FW; this pointer points to single FH x FW filter data.
132+
133+
Output - the whole output is of N x F x OH x OW. This pointer points to single OH x OW output image data.
134+
135+
--*/
116136
{
117137
const size_t H = Parameters->InputShape[0];
118138
const size_t W = Parameters->InputShape[1];
@@ -185,14 +205,14 @@ static void DepthwiseConv3x3Stride1PadLe1Neon(
185205
}
186206

187207
const size_t base = static_cast<size_t>(iw);
188-
float32x4_t acc = MlasZeroFloat32x4();
208+
MLAS_FLOAT32X4 acc = MlasZeroFloat32x4();
189209

190210
DepthwiseAccumulateRowVector(acc, row0, base, w00, w01, w02);
191211
DepthwiseAccumulateRowVector(acc, row1, base, w10, w11, w12);
192212
DepthwiseAccumulateRowVector(acc, row2, base, w20, w21, w22);
193213

194214
if (accumulate_output) {
195-
const float32x4_t prev = MlasLoadFloat32x4(out_row + ow);
215+
const MLAS_FLOAT32X4 prev = MlasLoadFloat32x4(out_row + ow);
196216
acc = MlasMultiplyAddFloat32x4(prev, beta, acc);
197217
}
198218

@@ -230,35 +250,6 @@ static void DepthwiseConv3x3Stride1PadLe1Neon(
230250
}
231251
}
232252

233-
static
234-
void
235-
MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(
236-
const MLAS_CONV_PARAMETERS* Parameters,
237-
const float* Input,
238-
const float* Filter,
239-
float* Output
240-
)
241-
/*++
242-
243-
Routine Description:
244-
245-
This routine is an inner kernel to compute convolution on one channel input with one filter channel.
246-
247-
Arguments:
248-
249-
Parameters - conv parameters calculated based on conv parameters like padding, strides, dilations, etc.
250-
251-
Input - input channel data start. Input is NCHW, so this pointer points to single H x W image data.
252-
253-
Filter - Whole filters are of F x CpG x FH x FW, this filter points to single FH x FW filter data.
254-
255-
Output - whole output are of N x F x OH x OW. This pointer points to single OH x OW output image data.
256-
257-
--*/
258-
{
259-
DepthwiseConv3x3Stride1PadLe1Neon(Parameters, Input, Filter, Output);
260-
}
261-
262253
void MlasConvDepthwiseFloat_CHW(
263254
const MLAS_CONV_PARAMETERS* Parameters,
264255
const float* Input,
@@ -292,6 +283,22 @@ Routine Description:
292283
293284
--*/
294285
{
286+
assert(Parameters->Dimensions == 2);
287+
assert(Parameters->FilterCount == 1);
288+
assert(Parameters->InputChannels == 1);
289+
assert(Parameters->KernelShape[0] == 3);
290+
assert(Parameters->KernelShape[1] == 3);
291+
assert(Parameters->StrideShape[0] == 1);
292+
assert(Parameters->StrideShape[1] == 1);
293+
assert(Parameters->DilationShape[0] == 1);
294+
assert(Parameters->DilationShape[1] == 1);
295+
assert(Parameters->Padding[0] <= 1);
296+
assert(Parameters->Padding[1] <= 1);
297+
assert(Parameters->Padding[2] <= 1);
298+
assert(Parameters->Padding[3] <= 1);
299+
295300
MLAS_UNREFERENCED_PARAMETER(Zeros);
301+
302+
// Kernel dispatch
296303
MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(Parameters, Input, Filter, Output);
297304
}

0 commit comments

Comments
 (0)