@@ -892,6 +892,77 @@ Return Value:
892892
893893#endif
894894
895+ void
896+ MlasDepthwiseWithMultiplierThreaded (
897+ void * Context,
898+ ptrdiff_t Index
899+ )
900+ /* ++
901+
902+ Routine Description:
903+
904+ This routine is invoked from a worker thread to execute a segment of a
905+ convolution operation.
906+
907+ If using this, the entire convolution operation is parallelized on the
908+ (batch size * group count) parameter and this routine has logic to
909+ perform a specific thread's shard of the entire Convolution operation.
910+
911+ Arguments:
912+
913+ Context - Supplies the pointer to the context for the threaded operation.
914+
915+ Index - Supplies the current index of the threaded operation.
916+
917+ Return Value:
918+
919+ None.
920+
921+ --*/
922+ {
923+ MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
924+
925+ const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters ;
926+ const size_t GroupCount = Parameters->GroupCount ;
927+ const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
928+
929+ size_t BatchGroupStart;
930+ size_t BatchGroupRemaining;
931+
932+ MlasPartitionWork (Index, WorkBlock->TargetThreadCount , BatchGroupCount,
933+ &BatchGroupStart, &BatchGroupRemaining);
934+
935+ size_t BatchGroupEnd = BatchGroupStart + BatchGroupRemaining;
936+
937+ const size_t FilterCount = Parameters->FilterCount ;
938+ const size_t OutputSize = Parameters->OutputSize ;
939+ const size_t K = Parameters->K ;
940+
941+ const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize ;
942+ const size_t OutputGroupSize = FilterCount * OutputSize;
943+ const size_t FilterGroupSize = FilterCount * K;
944+
945+ const float * input = WorkBlock->Input + BatchGroupStart * InputGroupSize;
946+ float * output = WorkBlock->Output + BatchGroupStart * OutputGroupSize;
947+
948+ for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
949+ size_t group = bg % GroupCount;
950+
951+ const float * filter = WorkBlock->Filter + group * FilterGroupSize;
952+ const float * bias = WorkBlock->Bias ;
953+ if (bias != nullptr ) {
954+ bias += group * FilterCount;
955+ }
956+
957+ MlasConvDepthwiseWithMultiplierFloat_CHW (Parameters, input, filter, output);
958+ MlasActivation (Parameters->Activation , output, bias, FilterCount,
959+ OutputSize, OutputSize);
960+
961+ input += InputGroupSize;
962+ output += OutputGroupSize;
963+ }
964+ }
965+
895966inline
896967bool
897968MlasConvTryMultithread (
@@ -1106,7 +1177,6 @@ Return Value:
11061177 return ;
11071178 }
11081179
1109-
11101180#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
11111181
11121182 if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1 ) || (GroupCount > 1 ))) {
@@ -1135,6 +1205,28 @@ Return Value:
11351205
11361206#endif
11371207
1208+ if (Algorithm == MlasConvAlgorithmDepthwiseWithMultiplier && ((BatchCount > 1 ) || (GroupCount > 1 ))) {
1209+ const size_t BatchGroupCount = BatchCount * GroupCount;
1210+ ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount (ThreadPool);
1211+
1212+ if (static_cast <size_t >(TargetThreadCount) >= BatchGroupCount) {
1213+ TargetThreadCount = static_cast <ptrdiff_t >(BatchGroupCount);
1214+ }
1215+
1216+ MLAS_CONV_WORK_BLOCK WorkBlock;
1217+ WorkBlock.Parameters = Parameters;
1218+ WorkBlock.Input = Input;
1219+ WorkBlock.Filter = Filter;
1220+ WorkBlock.Bias = Bias;
1221+ WorkBlock.WorkingBuffer = nullptr ;
1222+ WorkBlock.Output = Output;
1223+ WorkBlock.TargetThreadCount = TargetThreadCount;
1224+
1225+ MlasExecuteThreaded (MlasDepthwiseWithMultiplierThreaded, &WorkBlock,
1226+ TargetThreadCount, ThreadPool);
1227+ return ;
1228+ }
1229+
11381230 //
11391231 // Iterate over each batch and group.
11401232 //
@@ -1209,6 +1301,13 @@ Return Value:
12091301
12101302#endif
12111303
1304+ case MlasConvAlgorithmDepthwiseWithMultiplier:
1305+ {
1306+ MlasConvDepthwiseWithMultiplierFloat_CHW (Parameters, Input, filter, Output);
1307+ MlasActivation (Parameters->Activation , Output, bias, FilterCount, OutputSize, OutputSize);
1308+ break ;
1309+ }
1310+
12121311 case MlasConvAlgorithmExpandThenGemmSegmented:
12131312 {
12141313 //
@@ -1453,6 +1552,26 @@ Return Value:
14531552
14541553 } else {
14551554
1555+ // Commonly found in MobileNet like models, where the depthwise convolution with
1556+ // depth_multiplier = 2 is used together with 7x7 kernel shape, stride = 2 and dilation = 1.
1557+ // This is a very specific scenario, but it is worth to have a specialized kernel for it given
1558+ // the popularity of MobileNet models.
1559+ if (Dimensions == 2
1560+ // depthwise convolution
1561+ && Parameters->GroupCount > 1
1562+ && Parameters->InputChannels == 1
1563+ // depth_multiplier = 2
1564+ && Parameters->FilterCount == 2
1565+ // current scope for specialized kernel is for the 7x7 kernel shape
1566+ && Parameters->KernelShape [0 ] == 7 && Parameters->KernelShape [1 ] == 7
1567+ // keep this specialized kernel only for stride = 2x2
1568+ && Parameters->StrideShape [0 ] == 2 && Parameters->StrideShape [1 ] == 2
1569+ // keep this specialized kernel only for dilation = 1x1
1570+ && Parameters->DilationShape [0 ] == 1 && Parameters->DilationShape [1 ] == 1 ) {
1571+ Parameters->Algorithm = MlasConvAlgorithmDepthwiseWithMultiplier;
1572+ return ;
1573+ }
1574+
14561575#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64)
14571576
14581577 // Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution.
@@ -1468,12 +1587,12 @@ Return Value:
14681587 #endif
14691588
14701589 if (Dimensions == 2
1471- && Parameters->FilterCount == 1 && Parameters->InputChannels == 1
1472- && Parameters->KernelShape [0 ] == 3 && Parameters->KernelShape [1 ] == 3
1473- && Parameters->Padding [0 ] <= 1 && Parameters->Padding [1 ] <= 1
1474- && Parameters->Padding [2 ] <= 1 && Parameters->Padding [3 ] <= 1
1475- && depthwise_conv_stride_support_check
1476- && Parameters->DilationShape [0 ] == 1 && Parameters->DilationShape [1 ] == 1 ) {
1590+ && Parameters->FilterCount == 1 && Parameters->InputChannels == 1
1591+ && Parameters->KernelShape [0 ] == 3 && Parameters->KernelShape [1 ] == 3
1592+ && Parameters->Padding [0 ] <= 1 && Parameters->Padding [1 ] <= 1
1593+ && Parameters->Padding [2 ] <= 1 && Parameters->Padding [3 ] <= 1
1594+ && depthwise_conv_stride_support_check
1595+ && Parameters->DilationShape [0 ] == 1 && Parameters->DilationShape [1 ] == 1 ) {
14771596
14781597 *WorkingBufferSize = Parameters->InputShape [1 ] + 2 ;
14791598 Parameters->Algorithm = MlasConvAlgorithmDepthwise;
0 commit comments