Skip to content

Commit e2c2a1a

Browse files
committed
im2col test 3
1 parent a6b70d4 commit e2c2a1a

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,33 +41,36 @@ void main() {
4141
const uint ic = gl_GlobalInvocationID.z % p.IC;
4242

4343
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
44+
const int oh_s1 = int(oh) * p.s1;
45+
const uint base_linear_idx = gidx * NUM_ITER;
4446
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
4547
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
46-
const int oh_s1 = int(oh) * p.s1;
47-
const uint base_idx = gidx * NUM_ITER;
48-
49-
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
5048

51-
const uint i = base_idx + idx;
49+
uint current_kx = base_linear_idx / ksize;
50+
uint rem = base_linear_idx - (current_kx * ksize); // equivalent to base_linear_idx % ksize
51+
uint current_ky = rem / p.OW;
52+
uint current_ix = rem % p.OW;
5253

53-
if (i >= p.pelements) {
54-
break;
54+
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
55+
uint linear_idx = base_linear_idx + idx;
56+
if (linear_idx >= p.pelements) {
57+
continue;
5558
}
5659

57-
const uint kx = i / ksize;
58-
const uint rem = i % ksize;
59-
const uint ky = rem / p.OW;
60-
const uint ix = rem % p.OW;
61-
62-
const int iiw = int(ix) * p.s0 + int(kx) * p.d0 - p.p0;
63-
const int iih = oh_s1 + int(ky) * p.d1 - p.p1;
60+
int iiw = int(current_ix) * p.s0 + int(current_kx) * p.d0 - p.p0;
61+
int iih = oh_s1 + int(current_ky) * p.d1 - p.p1;
62+
uint dst_offset = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
6463

65-
const uint dst_offset = dst_base + ix * p.CHW + ky * p.KW + kx;
64+
bool valid = (iih >= 0 && iih < int(p.IH) && iiw >= 0 && iiw < int(p.IW));
65+
data_d[dst_offset] = D_TYPE(valid ? data_a[src_base + uint(iih) * p.IW + uint(iiw)] : 0);
6666

67-
data_d[dst_offset] = D_TYPE((iih >= 0 && iih < int(p.IH) &&
68-
iiw >= 0 && iiw < int(p.IW))
69-
? data_a[src_base + uint(iih) * p.IW + uint(iiw)]
70-
: 0);
67+
if (++current_ix == p.OW) {
68+
current_ix = 0;
69+
if (++current_ky == (ksize / p.OW)) {
70+
current_ky = 0;
71+
current_kx++;
72+
}
73+
}
7174
}
7275

7376
}

0 commit comments

Comments
 (0)