@@ -782,6 +782,14 @@ kernel void kernel_silu_4(
782782 dst[tpig] = x / (1 .0f + exp (-x));
783783}
784784
785+ kernel void kernel_elu( // element-wise ELU: each thread handles the single element at index tpig
786+ device const float * src0, // input values, read at src0[tpig]
787+ device float * dst, // output values, written at dst[tpig]
788+ uint tpig[[thread_position_in_grid]]) { // index of this thread within the dispatch grid
789+ device const float & x = src0[tpig]; // no bounds check: assumes the grid is no larger than the buffers - TODO confirm at the dispatch site
790+ dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); // positive inputs pass through; all others map to exp(x) - 1
791+ }
792+
785793kernel void kernel_sqr (
786794 device const float * src0,
787795 device float * dst,
@@ -2137,20 +2145,34 @@ kernel void kernel_im2col(
21372145 uint3 tgpg[[threadgroups_per_grid]],
21382146 uint3 tpitg[[thread_position_in_threadgroup]],
21392147 uint3 ntg[[threads_per_threadgroup]]) {
2140- const int32_t iiw = tgpig[2 ] * s0 + tpitg[2 ] * d0 - p0;
2141- const int32_t iih = tgpig[1 ] * s1 + tpitg[1 ] * d1 - p1;
2148+ // const int64_t IC = tgpg[0];
2149+ const int64_t OH = tgpg[1 ];
2150+ const int64_t OW = tgpg[2 ];
21422151
2143- const int32_t offset_dst =
2144- (tpitg[0 ] * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * CHW +
2145- (tgpig[0 ] * (ntg[1 ] * ntg[2 ]) + tpitg[1 ] * ntg[2 ] + tpitg[2 ]);
2152+ // const int64_t N = ntg[0];
2153+ const int64_t KH = ntg[1 ];
2154+ const int64_t KW = ntg[2 ];
2155+
2156+ const int64_t in = tpitg[0 ];
2157+ const int64_t ikh = tpitg[1 ];
2158+ const int64_t ikw = tpitg[2 ];
2159+
2160+ const int64_t iic = tgpig[0 ];
2161+ const int64_t ioh = tgpig[1 ];
2162+ const int64_t iow = tgpig[2 ];
2163+
2164+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
2165+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
2166+
2167+ const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
21462168
21472169 device T * pdst = (device T *) (dst);
21482170
21492171 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
21502172 pdst[offset_dst] = 0 .0f ;
21512173 } else {
2152- const int32_t offset_src = tpitg[ 0 ] * ofs0 + tgpig[ 0 ] * ofs1 ;
2153- pdst[offset_dst] = x[offset_src + iih * IW + iiw ];
2174+ const int64_t offset_src = in* ofs0 + iic*ofs1 + iih*IW + iiw ;
2175+ pdst[offset_dst] = x[offset_src];
21542176 }
21552177}
21562178
@@ -2201,25 +2223,25 @@ kernel void kernel_im2col_ext(
22012223 uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
22022224 uint3 tpitg[[thread_position_in_threadgroup]],
22032225 uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
2204- const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
2226+ const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
22052227
2206- const int32_t d = tgpig[0 ] / CHW;
2207- const int32_t chw = tgpig[0 ] % CHW;
2208- const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2209- const int32_t HW = tgpig[0 ] % KHW;
2228+ const int64_t d = tgpig[0 ] / CHW;
2229+ const int64_t chw = tgpig[0 ] % CHW;
2230+ const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2231+ const int64_t HW = tgpig[0 ] % KHW;
22102232
2211- const int32_t tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
2233+ const int64_t tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
22122234 if (tpitg_0 >= N) {
22132235 return ;
22142236 }
22152237
2216- const int32_t tpitg_1 = HW / KW;
2217- const int32_t tpitg_2 = HW % KW;
2238+ const int64_t tpitg_1 = HW / KW;
2239+ const int64_t tpitg_2 = HW % KW;
22182240
2219- const int32_t iiw = tgpig[2 ] * s0 + tpitg_2 * d0 - p0;
2220- const int32_t iih = tgpig[1 ] * s1 + tpitg_1 * d1 - p1;
2241+ const int64_t iiw = tgpig[2 ] * s0 + tpitg_2 * d0 - p0;
2242+ const int64_t iih = tgpig[1 ] * s1 + tpitg_1 * d1 - p1;
22212243
2222- const int32_t offset_dst =
2244+ const int64_t offset_dst =
22232245 (tpitg_0 * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * CHW +
22242246 (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
22252247
@@ -2228,7 +2250,7 @@ kernel void kernel_im2col_ext(
22282250 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
22292251 pdst[offset_dst] = 0 .0f ;
22302252 } else {
2231- const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2253+ const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
22322254 pdst[offset_dst] = x[offset_src + iih * IW + iiw];
22332255 }
22342256}
0 commit comments