@@ -37,35 +37,37 @@ void main() {
3737 const uint gidx = gl_GlobalInvocationID.x;
3838
3939 const uint oh = gl_GlobalInvocationID.y;
40- const uint batch_ic = gl_GlobalInvocationID.z;
40+ const uint batch = gl_GlobalInvocationID.z / p.IC;
41+ const uint ic = gl_GlobalInvocationID.z % p.IC;
4142
42- const uint batch = batch_ic / p.IC;
43- const uint ic = batch_ic % p.IC;
44-
45- const uint ksize = p.OW * ((p.KH > 1) ? p.KW : 1);
43+ const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
4644 const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
47- const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * p.KW * p.KH;
45+ const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
46+ const int oh_s1 = int(oh) * p.s1;
47+ const uint base_idx = gidx * NUM_ITER;
4848
4949 [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
5050
51- const uint i = gidx * NUM_ITER + idx;
51+ uint i = base_idx + idx;
52+
5253 if (i >= p.pelements) {
53- continue ;
54+ break ;
5455 }
5556
56- const uint kx = i / ksize;
57- const uint ky = (i % ksize) / p.OW;
58- const uint ix = i % p.OW;
57+ const uint kx = i / ksize;
58+ const uint rem = i % ksize;
59+ const uint ky = rem / p.OW;
60+ const uint ix = rem % p.OW;
5961
60- const int iiw = int(ix * uint( p.s0)) + int(kx * uint( p.d0)) - p.p0;
61- const int iih = int(oh * uint(p.s1)) + int(ky * uint( p.d1)) - p.p1;
62+ int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0;
63+ int iih = oh_s1 + int(ky) * p.d1 - p.p1;
6264
6365 const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx;
6466
65- const bool valid = iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW);
66- const uint src_offset = src_base + uint(iih) * p.IW + uint(iiw);
67-
68- data_d[dst_offset] = D_TYPE(valid ? data_a[src_offset] : 0. 0);
67+ data_d[dst_offset] = D_TYPE(( iih >= 0 && iih < int(p.IH) &&
68+ iiw >= 0 && iiw < int( p.IW))
69+ ? data_a[src_base + uint(iih) * p.IW + uint(iiw)]
70+ : 0);
6971 }
7072
7173}
0 commit comments