@@ -2145,20 +2145,34 @@ kernel void kernel_im2col(
2145
2145
uint3 tgpg[[threadgroups_per_grid]],
2146
2146
uint3 tpitg[[thread_position_in_threadgroup]],
2147
2147
uint3 ntg[[threads_per_threadgroup]]) {
2148
- const int32_t iiw = tgpig[2 ] * s0 + tpitg[2 ] * d0 - p0;
2149
- const int32_t iih = tgpig[1 ] * s1 + tpitg[1 ] * d1 - p1;
2148
+ // const int64_t IC = tgpg[0];
2149
+ const int64_t OH = tgpg[1 ];
2150
+ const int64_t OW = tgpg[2 ];
2150
2151
2151
- const int32_t offset_dst =
2152
- (tpitg[0 ] * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * CHW +
2153
- (tgpig[0 ] * (ntg[1 ] * ntg[2 ]) + tpitg[1 ] * ntg[2 ] + tpitg[2 ]);
2152
+ // const int64_t N = ntg[0];
2153
+ const int64_t KH = ntg[1 ];
2154
+ const int64_t KW = ntg[2 ];
2155
+
2156
+ const int64_t in = tpitg[0 ];
2157
+ const int64_t ikh = tpitg[1 ];
2158
+ const int64_t ikw = tpitg[2 ];
2159
+
2160
+ const int64_t iic = tgpig[0 ];
2161
+ const int64_t ioh = tgpig[1 ];
2162
+ const int64_t iow = tgpig[2 ];
2163
+
2164
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
2165
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
2166
+
2167
+ const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
2154
2168
2155
2169
device T * pdst = (device T *) (dst);
2156
2170
2157
2171
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
2158
2172
pdst[offset_dst] = 0 .0f ;
2159
2173
} else {
2160
- const int32_t offset_src = tpitg[ 0 ] * ofs0 + tgpig[ 0 ] * ofs1 ;
2161
- pdst[offset_dst] = x[offset_src + iih * IW + iiw ];
2174
+ const int64_t offset_src = in* ofs0 + iic*ofs1 + iih*IW + iiw ;
2175
+ pdst[offset_dst] = x[offset_src];
2162
2176
}
2163
2177
}
2164
2178
@@ -2209,25 +2223,25 @@ kernel void kernel_im2col_ext(
2209
2223
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
2210
2224
uint3 tpitg[[thread_position_in_threadgroup]],
2211
2225
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
2212
- const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
2226
+ const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
2213
2227
2214
- const int32_t d = tgpig[0 ] / CHW;
2215
- const int32_t chw = tgpig[0 ] % CHW;
2216
- const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2217
- const int32_t HW = tgpig[0 ] % KHW;
2228
+ const int64_t d = tgpig[0 ] / CHW;
2229
+ const int64_t chw = tgpig[0 ] % CHW;
2230
+ const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2231
+ const int64_t HW = tgpig[0 ] % KHW;
2218
2232
2219
- const int32_t tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
2233
+ const int64_t tpitg_0 = (d * ntg[0 ]) + tpitg[0 ];
2220
2234
if (tpitg_0 >= N) {
2221
2235
return ;
2222
2236
}
2223
2237
2224
- const int32_t tpitg_1 = HW / KW;
2225
- const int32_t tpitg_2 = HW % KW;
2238
+ const int64_t tpitg_1 = HW / KW;
2239
+ const int64_t tpitg_2 = HW % KW;
2226
2240
2227
- const int32_t iiw = tgpig[2 ] * s0 + tpitg_2 * d0 - p0;
2228
- const int32_t iih = tgpig[1 ] * s1 + tpitg_1 * d1 - p1;
2241
+ const int64_t iiw = tgpig[2 ] * s0 + tpitg_2 * d0 - p0;
2242
+ const int64_t iih = tgpig[1 ] * s1 + tpitg_1 * d1 - p1;
2229
2243
2230
- const int32_t offset_dst =
2244
+ const int64_t offset_dst =
2231
2245
(tpitg_0 * tgpg[1 ] * tgpg[2 ] + tgpig[1 ] * tgpg[2 ] + tgpig[2 ]) * CHW +
2232
2246
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
2233
2247
@@ -2236,7 +2250,7 @@ kernel void kernel_im2col_ext(
2236
2250
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
2237
2251
pdst[offset_dst] = 0 .0f ;
2238
2252
} else {
2239
- const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2253
+ const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2240
2254
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
2241
2255
}
2242
2256
}
0 commit comments