Skip to content

Commit ec6ead4

Browse files
add_depthwise_conv3_unequal (#7704)
1 parent cda13fa commit ec6ead4

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

lite/backends/metal/metal_kernel/texture/ConvAddReluMetal.metal

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,68 @@ kernel void depthwise_conv_3x3(texture2d_array<ftype, access::sample> inTexture[
590590
outTexture.write(relu, gid.xy, gid.z);
591591
}
592592

593+
kernel void depthwise_conv_3x3_unequal(
594+
texture2d_array<ftype, access::sample> inTexture[[texture(0)]],
595+
texture2d_array<ftype, access::sample> biasTexture[[texture(1)]],
596+
texture2d_array<ftype, access::write> outTexture[[texture(2)]],
597+
constant MetalConvParam& param[[buffer(0)]],
598+
const device ftype* weights[[buffer(1)]],
599+
uint3 gid[[thread_position_in_grid]]) {
600+
if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
601+
gid.z >= outTexture.get_array_size()) {
602+
return;
603+
}
604+
uint inc = inTexture.get_array_size();
605+
uint inw = inTexture.get_width();
606+
uint inh = inTexture.get_height();
607+
uint output_slice = gid.z / 2;
608+
609+
ushort2 stride = ushort2(param.strideX, param.strideY);
610+
ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);
611+
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
612+
const uint kernelHXW = 9;
613+
uint weithTo = gid.z * kernelHXW * 4;
614+
615+
ftype4 output;
616+
if (param.hasAddOp) {
617+
output = get_bias(gid, param.addParam, biasTexture);
618+
}
619+
620+
ushort dilation_x = param.dilationX;
621+
ushort dilation_y = param.dilationY;
622+
float2 intput_sample[9];
623+
intput_sample[0] = float2(posInInput.x - dilation_x, posInInput.y - dilation_y);
624+
intput_sample[1] = float2(posInInput.x, posInInput.y - dilation_y);
625+
intput_sample[2] = float2(posInInput.x + dilation_x, posInInput.y - dilation_y);
626+
intput_sample[3] = float2(posInInput.x - dilation_x, posInInput.y);
627+
intput_sample[4] = float2(posInInput.x, posInInput.y);
628+
intput_sample[5] = float2(posInInput.x + dilation_x, posInInput.y);
629+
intput_sample[6] = float2(posInInput.x - dilation_x, posInInput.y + dilation_y);
630+
intput_sample[7] = float2(posInInput.x, posInInput.y + dilation_y);
631+
intput_sample[8] = float2(posInInput.x + dilation_x, posInInput.y + dilation_y);
632+
633+
ftype4 input;
634+
for (int j = 0; j < 9; ++j) {
635+
// if(intput_sample[j].x >= 0 && intput_sample[j].x < inw && intput_sample[j].y >= 0 &&
636+
// intput_sample[j].y < inh){
637+
input = inTexture.sample(sample, intput_sample[j], output_slice);
638+
if (gid.z % 2 == 0) {
639+
output.x += input.x * weights[weithTo + 0 * kernelHXW + j];
640+
output.y += input.x * weights[weithTo + 1 * kernelHXW + j];
641+
output.z += input.y * weights[weithTo + 2 * kernelHXW + j];
642+
output.w += input.y * weights[weithTo + 3 * kernelHXW + j];
643+
} else {
644+
output.x += input.z * weights[weithTo + 0 * kernelHXW + j];
645+
output.y += input.z * weights[weithTo + 1 * kernelHXW + j];
646+
output.z += input.w * weights[weithTo + 2 * kernelHXW + j];
647+
output.w += input.w * weights[weithTo + 3 * kernelHXW + j];
648+
}
649+
//}
650+
}
651+
ftype4 relu = activation(output, param.activationParam);
652+
outTexture.write(relu, gid.xy, gid.z);
653+
}
654+
593655
kernel void depthwise_conv_5x5(texture2d_array<ftype, access::sample> inTexture[[texture(0)]],
594656
texture2d_array<ftype, access::sample> biasTexture[[texture(1)]],
595657
texture2d_array<ftype, access::write> outTexture[[texture(2)]],

lite/kernels/metal/image_op/conv2d_image_compute.mm

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,10 @@
257257
}
258258
#endif
259259
return "conv_3x3";
260-
} else {
260+
} else if ((input_c == (filter_c * param.groups)) && filter_n == input_c) {
261261
return "group_conv_3x3";
262+
} else {
263+
return "depthwise_conv_3x3_unequal";
262264
}
263265
} else if (filter_w == 1 && filter_h == 5) {
264266
return "conv_5x1";
@@ -426,7 +428,10 @@
426428
bool pad_when_one_ch =
427429
!(param.filter->dims()[1] == 1 && param.filter->dims()[0] == param.x->dims()[1]);
428430
filter_buffer_ = std::make_shared<MetalBuffer>(metal_context_, param.filter->dims());
429-
filter_buffer_->pad_when_one_channel_ = pad_when_one_ch;
431+
if (param.groups != 1 && param.filter->dims()[0] != param.x->dims()[1]) {
432+
filter_buffer_->pad_when_one_channel_ = false;
433+
} else
434+
filter_buffer_->pad_when_one_channel_ = pad_when_one_ch;
430435
filter_buffer_->CopyFromNCHW<float>(filter);
431436
}
432437

0 commit comments

Comments
 (0)