Skip to content

Commit 12aaeae

Browse files
committed
Vulkan: revert the order of the index calculation and bound check in conv_2d shader
1 parent f5ae689 commit 12aaeae

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -229,16 +229,14 @@ void main() {
229229
uint32_t B_ly = r_offset + Ar;
230230
uint32_t B_lx = Ac;
231231
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
232-
float val;
233-
if (K_idx >= K || CRS_idx_a >= CRS) {
234-
val = 0.0;
235-
} else {
236232
#ifdef TRANSPOSE
237-
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
233+
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
238234
#else
239-
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
235+
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
240236
#endif
241-
val = knl_data[knl_idx];
237+
float val = knl_data[knl_idx];
238+
if (K_idx >= K || CRS_idx_a >= CRS) {
239+
val = 0.0;
242240
}
243241
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
244242
}
@@ -286,18 +284,16 @@ void main() {
286284
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
287285
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
288286
#endif
289-
float val;
287+
uint32_t src_idx =
288+
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
289+
float val = src_data[src_idx];
290290
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
291291
|| int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W
292292
#ifdef TRANSPOSE
293293
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
294294
#endif
295295
) {
296296
val = 0.0;
297-
} else {
298-
uint32_t src_idx =
299-
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
300-
val = src_data[src_idx];
301297
}
302298
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
303299
}

0 commit comments

Comments
 (0)