@@ -590,6 +590,68 @@ kernel void depthwise_conv_3x3(texture2d_array<ftype, access::sample> inTexture[
590
590
outTexture.write (relu, gid.xy , gid.z );
591
591
}
592
592
593
// Depthwise 3x3 convolution in which every input channel feeds two output
// channels.
//
// Each thread produces one RGBA texel of `outTexture` at (gid.x, gid.y) in
// array slice gid.z. Output slice gid.z reads input slice gid.z / 2:
//   - even gid.z uses input channels x and y,
//   - odd  gid.z uses input channels z and w,
// and each selected input channel contributes to two consecutive output
// channels (x -> out.xy, y -> out.zw for even slices; z -> out.xy, w -> out.zw
// for odd slices).
// NOTE(review): this implies the output has twice as many channels as the
// input ("unequal"); confirm against the host-side dispatch code.
//
// Parameters:
//   inTexture   - input feature map, sampled with pixel coordinates.
//   biasTexture - bias source handed to get_bias() when param.hasAddOp is set.
//   outTexture  - destination feature map; its dimensions bound the grid.
//   param       - stride, offset, dilation, bias and activation settings.
//   weights     - filter taps; the four output channels of slice gid.z start
//                 at gid.z * 36, with 9 taps per channel stored contiguously.
//   gid         - thread position in the dispatch grid.
kernel void depthwise_conv_3x3_unequal(
    texture2d_array<ftype, access::sample> inTexture[[texture(0)]],
    texture2d_array<ftype, access::sample> biasTexture[[texture(1)]],
    texture2d_array<ftype, access::write> outTexture[[texture(2)]],
    constant MetalConvParam& param[[buffer(0)]],
    const device ftype* weights[[buffer(1)]],
    uint3 gid[[thread_position_in_grid]]) {
    // Threads that fall outside the output texture have nothing to write.
    if (gid.x >= outTexture.get_width() || gid.y >= outTexture.get_height() ||
        gid.z >= outTexture.get_array_size()) {
        return;
    }

    // Two output slices share one input slice; the parity of gid.z selects
    // which half (xy or zw) of the sampled input texel is consumed.
    const uint input_slice = gid.z / 2;
    const bool use_low_half = (gid.z % 2 == 0);

    ushort2 stride = ushort2(param.strideX, param.strideY);
    ushort2 posInInput = ushort2(gid.xy) * stride + ushort2(param.offsetX, param.offsetY);

    // Out-of-range taps are handled by the sampler's clamp_to_zero address
    // mode rather than by an explicit bounds check.
    constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);

    const uint kernelHXW = 9;
    const uint weightTo = gid.z * kernelHXW * 4;

    // Start the accumulator from zero so the result is well defined when no
    // bias is added; previously it was left uninitialized in that case.
    ftype4 output = ftype4(0.0);
    if (param.hasAddOp) {
        output = get_bias(gid, param.addParam, biasTexture);
    }

    ushort dilation_x = param.dilationX;
    ushort dilation_y = param.dilationY;

    // Sample positions of the 3x3 window in row-major order (top-left first),
    // matching the tap index j used to address the weights below.
    float2 input_sample[9];
    input_sample[0] = float2(posInInput.x - dilation_x, posInInput.y - dilation_y);
    input_sample[1] = float2(posInInput.x, posInInput.y - dilation_y);
    input_sample[2] = float2(posInInput.x + dilation_x, posInInput.y - dilation_y);
    input_sample[3] = float2(posInInput.x - dilation_x, posInInput.y);
    input_sample[4] = float2(posInInput.x, posInInput.y);
    input_sample[5] = float2(posInInput.x + dilation_x, posInInput.y);
    input_sample[6] = float2(posInInput.x - dilation_x, posInInput.y + dilation_y);
    input_sample[7] = float2(posInInput.x, posInInput.y + dilation_y);
    input_sample[8] = float2(posInInput.x + dilation_x, posInInput.y + dilation_y);

    for (int j = 0; j < 9; ++j) {
        const ftype4 input = inTexture.sample(sample, input_sample[j], input_slice);

        // Pick the pair of input channels that belongs to this output slice.
        const ftype first = use_low_half ? input.x : input.z;
        const ftype second = use_low_half ? input.y : input.w;

        // Each input channel drives two output channels, each with its own
        // 9-tap filter.
        output.x += first * weights[weightTo + 0 * kernelHXW + j];
        output.y += first * weights[weightTo + 1 * kernelHXW + j];
        output.z += second * weights[weightTo + 2 * kernelHXW + j];
        output.w += second * weights[weightTo + 3 * kernelHXW + j];
    }

    ftype4 relu = activation(output, param.activationParam);
    outTexture.write(relu, gid.xy, gid.z);
}
654
+
593
655
kernel void depthwise_conv_5x5 (texture2d_array<ftype, access::sample> inTexture[[texture(0 )]],
594
656
texture2d_array<ftype, access::sample> biasTexture[[texture(1 )]],
595
657
texture2d_array<ftype, access::write> outTexture[[texture(2 )]],
0 commit comments