Skip to content

Commit 611fabd

Browse files
pminev authored and ggerganov committed
metal : fix offset integer overflows in im2col (ggml/1015)
-- While running StableDiffusion.cpp locally with Metal, some offsets overflow and result in incorrect calculations
1 parent 12b0ad9 commit 611fabd

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,20 +2145,34 @@ kernel void kernel_im2col(
21452145
uint3 tgpg[[threadgroups_per_grid]],
21462146
uint3 tpitg[[thread_position_in_threadgroup]],
21472147
uint3 ntg[[threads_per_threadgroup]]) {
2148-
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
2149-
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
2148+
// const int64_t IC = tgpg[0];
2149+
const int64_t OH = tgpg[1];
2150+
const int64_t OW = tgpg[2];
21502151

2151-
const int32_t offset_dst =
2152-
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
2153-
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
2152+
// const int64_t N = ntg[0];
2153+
const int64_t KH = ntg[1];
2154+
const int64_t KW = ntg[2];
2155+
2156+
const int64_t in = tpitg[0];
2157+
const int64_t ikh = tpitg[1];
2158+
const int64_t ikw = tpitg[2];
2159+
2160+
const int64_t iic = tgpig[0];
2161+
const int64_t ioh = tgpig[1];
2162+
const int64_t iow = tgpig[2];
2163+
2164+
const int64_t iiw = iow*s0 + ikw*d0 - p0;
2165+
const int64_t iih = ioh*s1 + ikh*d1 - p1;
2166+
2167+
const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
21542168

21552169
device T * pdst = (device T *) (dst);
21562170

21572171
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
21582172
pdst[offset_dst] = 0.0f;
21592173
} else {
2160-
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
2161-
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
2174+
const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
2175+
pdst[offset_dst] = x[offset_src];
21622176
}
21632177
}
21642178

@@ -2209,25 +2223,25 @@ kernel void kernel_im2col_ext(
22092223
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
22102224
uint3 tpitg[[thread_position_in_threadgroup]],
22112225
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
2212-
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
2226+
const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
22132227

2214-
const int32_t d = tgpig[0] / CHW;
2215-
const int32_t chw = tgpig[0] % CHW;
2216-
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2217-
const int32_t HW = tgpig[0] % KHW;
2228+
const int64_t d = tgpig[0] / CHW;
2229+
const int64_t chw = tgpig[0] % CHW;
2230+
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
2231+
const int64_t HW = tgpig[0] % KHW;
22182232

2219-
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
2233+
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
22202234
if (tpitg_0 >= N) {
22212235
return;
22222236
}
22232237

2224-
const int32_t tpitg_1 = HW / KW;
2225-
const int32_t tpitg_2 = HW % KW;
2238+
const int64_t tpitg_1 = HW / KW;
2239+
const int64_t tpitg_2 = HW % KW;
22262240

2227-
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
2228-
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
2241+
const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
2242+
const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
22292243

2230-
const int32_t offset_dst =
2244+
const int64_t offset_dst =
22312245
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
22322246
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
22332247

@@ -2236,7 +2250,7 @@ kernel void kernel_im2col_ext(
22362250
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
22372251
pdst[offset_dst] = 0.0f;
22382252
} else {
2239-
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2253+
const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
22402254
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
22412255
}
22422256
}

0 commit comments

Comments
 (0)