@@ -41,33 +41,36 @@ void main() {
4141 const uint ic = gl_GlobalInvocationID.z % p.IC;
4242
4343 const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
44+ const int oh_s1 = int(oh) * p.s1;
45+ const uint base_linear_idx = gidx * NUM_ITER;
4446 const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
4547 const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
46- const int oh_s1 = int(oh) * p.s1;
47- const uint base_idx = gidx * NUM_ITER;
48-
49- [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
5048
51- const uint i = base_idx + idx;
49+ uint current_kx = base_linear_idx / ksize;
50+ uint rem = base_linear_idx - (current_kx * ksize); // equivalent to base_linear_idx % ksize
51+ uint current_ky = rem / p.OW;
52+ uint current_ix = rem % p.OW;
5253
53- if (i >= p.pelements) {
54- break;
54+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
55+ uint linear_idx = base_linear_idx + idx;
56+ if (linear_idx >= p.pelements) {
57+ continue;
5558 }
5659
57- const uint kx = i / ksize;
58- const uint rem = i % ksize;
59- const uint ky = rem / p.OW;
60- const uint ix = rem % p.OW;
61-
62- const int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0;
63- const int iih = oh_s1 + int(ky) * p.d1 - p.p1;
60+ int iiw = int(current_ix) * p.s0 + int(current_kx) * p.d0 - p.p0;
61+ int iih = oh_s1 + int(current_ky) * p.d1 - p.p1;
62+ uint dst_offset = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
6463
65- const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx;
64+ bool valid = (iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW));
65+ data_d[dst_offset] = D_TYPE(valid ? data_a[src_base + uint(iih) * p.IW + uint(iiw)] : 0);
6666
67- data_d[dst_offset] = D_TYPE((iih >= 0 && iih < int(p.IH) &&
68- iiw >= 0 && iiw < int(p.IW))
69- ? data_a[src_base + uint(iih) * p.IW + uint(iiw)]
70- : 0);
67+ if (++current_ix == p.OW) {
68+ current_ix = 0;
69+ if (++current_ky == (ksize / p.OW)) {
70+ current_ky = 0;
71+ current_kx++;
72+ }
73+ }
7174 }
7275
7376}
0 commit comments