Skip to content

Commit f5ae689

Browse files
committed
Vulkan: add checking the push constants size limit and reuse conv2d_mm.comp for conv_transpose_2d operation
1 parent e029cda commit f5ae689

File tree

4 files changed

+63
-377
lines changed

4 files changed

+63
-377
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3603,26 +3603,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
36033603
device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
36043604
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
36053605
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3606+
#define CREATE_CONVS(spv_suffix) \
3607+
CREATE_CONV(conv2d, _f32, spv_suffix) \
3608+
CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
3609+
if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \
3610+
CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
3611+
CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \
3612+
}
36063613
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
36073614
if (device->coopmat2) {
3608-
CREATE_CONV(conv2d, _f32, _cm2)
3609-
CREATE_CONV(conv2d, _f16_f32, _cm2)
3610-
CREATE_CONV(conv_transpose_2d, _f32, _cm2)
3611-
CREATE_CONV(conv_transpose_2d, _f16_f32, _cm2)
3615+
CREATE_CONVS(_cm2)
36123616
} else
36133617
#endif
36143618
if (conv2d_UNROLL) {
3615-
CREATE_CONV(conv2d, _f32, _unroll)
3616-
CREATE_CONV(conv2d, _f16_f32, _unroll)
3617-
CREATE_CONV(conv_transpose_2d, _f32, _unroll)
3618-
CREATE_CONV(conv_transpose_2d, _f16_f32, _unroll)
3619+
CREATE_CONVS(_unroll)
36193620
} else {
3620-
CREATE_CONV(conv2d, _f32, )
3621-
CREATE_CONV(conv2d, _f16_f32, )
3622-
CREATE_CONV(conv_transpose_2d, _f32, )
3623-
CREATE_CONV(conv_transpose_2d, _f16_f32, )
3621+
CREATE_CONVS( )
36243622
}
36253623
#undef CREATE_CONV
3624+
#undef CREATE_CONVS
36263625
}
36273626

36283627
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -12722,6 +12721,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1272212721
// Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK
1272312722
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
1272412723
const vk_device& device = ggml_vk_get_device(ctx->device);
12724+
if (op->op == GGML_OP_CONV_TRANSPOSE_2D && !device->pipeline_conv_transpose_2d_f32[0]) {
12725+
return false;
12726+
}
1272512727
// Channel-contiguous format is not supported yet.
1272612728
return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1272712729
op->src[1]->type == GGML_TYPE_F32 &&

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
1717
layout(binding = 0) readonly buffer A {
1818
A_TYPE knl_data[];
19-
}; // src0 - kernel: [KW, KH, Cin, Cout]
19+
}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d
2020

2121
layout(binding = 1) readonly buffer B {
2222
B_TYPE src_data[];
@@ -66,6 +66,10 @@ layout(push_constant) uniform parameter {
6666
uint32_t KWKHmp; uint32_t KWKHL;
6767
uint32_t OWmp; uint32_t OWL;
6868
uint32_t OWOHmp; uint32_t OWOHL;
69+
#ifdef TRANSPOSE
70+
uint32_t s0mp; uint32_t s0L;
71+
uint32_t s1mp; uint32_t s1L;
72+
#endif
6973
}
7074

7175
p;
@@ -225,10 +229,16 @@ void main() {
225229
uint32_t B_ly = r_offset + Ar;
226230
uint32_t B_lx = Ac;
227231
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
228-
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
229-
float val = knl_data[knl_idx];
232+
float val;
230233
if (K_idx >= K || CRS_idx_a >= CRS) {
231234
val = 0.0;
235+
} else {
236+
#ifdef TRANSPOSE
237+
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
238+
#else
239+
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
240+
#endif
241+
val = knl_data[knl_idx];
232242
}
233243
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
234244
}
@@ -267,13 +277,27 @@ void main() {
267277
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
268278
#endif
269279

280+
#ifdef TRANSPOSE
281+
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
282+
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
283+
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
284+
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
285+
#else
270286
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
271287
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
272-
uint32_t src_idx =
273-
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
274-
float val = src_data[src_idx];
275-
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
288+
#endif
289+
float val;
290+
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
291+
|| int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W
292+
#ifdef TRANSPOSE
293+
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
294+
#endif
295+
) {
276296
val = 0.0;
297+
} else {
298+
uint32_t src_idx =
299+
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
300+
val = src_data[src_idx];
277301
}
278302
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
279303
}

0 commit comments

Comments
 (0)