Skip to content

Commit 0cf756e

Browse files
committed
vulkan: support SET_ROWS
Add variants of the copy_to_quant shader that do the SET_ROWS operation. Change these shaders to spread the work across the workgroup. The memory access pattern is probably not great (one thread per quant block), but should be fine for now.
1 parent b8eeb87 commit 0cf756e

File tree

3 files changed

+154
-24
lines changed

3 files changed

+154
-24
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 95 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ struct vk_device_struct {
437437
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
438438
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
439439
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
440+
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
440441
vk_pipeline pipeline_norm_f32;
441442
vk_pipeline pipeline_group_norm_f32;
442443
vk_pipeline pipeline_rms_norm_f32;
@@ -2738,19 +2739,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
27382739
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
27392740

27402741
if (device->float_controls_rte_fp16) {
2741-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2742-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2743-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2744-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2745-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2746-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2742+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2743+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2744+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2745+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2746+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2747+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
27472748
} else {
2748-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2749-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2750-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2751-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2752-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2753-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2749+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2750+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2751+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2752+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2753+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2754+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2755+
}
2756+
2757+
if (device->float_controls_rte_fp16) {
2758+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2759+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2760+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2761+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2762+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2763+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2764+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2765+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2766+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2767+
} else {
2768+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2769+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2770+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2771+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2772+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2773+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2774+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2775+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
2776+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {32, 1, 1}, {}, 1);
27542777
}
27552778

27562779
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@@ -6516,6 +6539,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65166539
case GGML_OP_CONT:
65176540
case GGML_OP_DUP:
65186541
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
6542+
case GGML_OP_SET_ROWS:
6543+
return ctx->device->pipeline_set_rows[dst->type];
65196544
case GGML_OP_SILU_BACK:
65206545
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
65216546
return ctx->device->pipeline_silu_back_f32;
@@ -6754,6 +6779,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
67546779
case GGML_OP_RMS_NORM:
67556780
case GGML_OP_CONV_2D_DW:
67566781
case GGML_OP_IM2COL:
6782+
case GGML_OP_SET_ROWS:
67576783
return true;
67586784
default:
67596785
return false;
@@ -7067,6 +7093,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70677093
ne *= ggml_type_size(src0->type) / 2;
70687094
}
70697095
}
7096+
// copy_to_quant has a workgroup size of 32, and each thread does QUANT_K elements.
7097+
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
7098+
// So divide by the quant block size (ggml_blck_size) here before splitting into 512x512 groups.
7099+
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7100+
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
7101+
}
70707102
if (ne > 262144) {
70717103
elements = { 512, 512, CEIL_DIV(ne, 262144) };
70727104
} else if (ne > 512) {
@@ -7075,6 +7107,19 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70757107
elements = { ne, 1, 1 };
70767108
}
70777109
} break;
7110+
case GGML_OP_SET_ROWS:
7111+
{
7112+
uint32_t ne = ggml_nelements(src0);
7113+
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
7114+
if (ne > 262144) {
7115+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
7116+
} else if (ne > 512) {
7117+
elements = { 512, CEIL_DIV(ne, 512), 1 };
7118+
} else {
7119+
elements = { ne, 1, 1 };
7120+
}
7121+
}
7122+
break;
70787123
default:
70797124
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
70807125
break;
@@ -7637,6 +7682,21 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
76377682
}, dryrun);
76387683
}
76397684

7685+
// Issues a GGML_OP_SET_ROWS operation through ggml_vk_op_f32 using
// vk_op_binary_push_constants. The push constants carry the element count of
// src0 followed by, for each of src0, src1 and dst, the four extents (ne[0..3])
// and the four strides (nb[0..3]) divided by that tensor's type size.
// `dryrun` is forwarded unchanged to ggml_vk_op_f32.
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7686+
const uint32_t src0_type_size = ggml_type_size(src0->type);
7687+
const uint32_t src1_type_size = ggml_type_size(src1->type);
7688+
const uint32_t dst_type_size = ggml_type_size(dst->type);
7689+
7690+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
7691+
// Total number of src0 elements.
(uint32_t)ggml_nelements(src0),
7692+
// src0 extents, then src0 strides expressed in units of the src0 type size.
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7693+
// src1 extents and strides, same layout as above.
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7694+
// dst extents and strides; strides are divided by the dst type size.
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7695+
// NOTE(review): the remaining push-constant fields are zero-initialized;
// their names are not visible in this diff - confirm against vk_op_binary_push_constants.
0,
7696+
0.0f, 0.0f, 0,
7697+
}, dryrun);
7698+
}
7699+
76407700
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
76417701
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
76427702
}
@@ -8957,6 +9017,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
89579017
case GGML_OP_CLAMP:
89589018
case GGML_OP_PAD:
89599019
case GGML_OP_CPY:
9020+
case GGML_OP_SET_ROWS:
89609021
case GGML_OP_CONT:
89619022
case GGML_OP_DUP:
89629023
case GGML_OP_SILU_BACK:
@@ -9023,6 +9084,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90239084
case GGML_OP_CLAMP:
90249085
case GGML_OP_PAD:
90259086
case GGML_OP_CPY:
9087+
case GGML_OP_SET_ROWS:
90269088
case GGML_OP_CONT:
90279089
case GGML_OP_DUP:
90289090
case GGML_OP_SILU_BACK:
@@ -9131,6 +9193,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
91319193
case GGML_OP_DUP:
91329194
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
91339195

9196+
break;
9197+
case GGML_OP_SET_ROWS:
9198+
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
9199+
91349200
break;
91359201
case GGML_OP_SILU_BACK:
91369202
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9346,6 +9412,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
93469412
case GGML_OP_CLAMP:
93479413
case GGML_OP_PAD:
93489414
case GGML_OP_CPY:
9415+
case GGML_OP_SET_ROWS:
93499416
case GGML_OP_CONT:
93509417
case GGML_OP_DUP:
93519418
case GGML_OP_SILU_BACK:
@@ -10411,9 +10478,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1041110478
} break;
1041210479
case GGML_OP_SET_ROWS:
1041310480
{
10414-
// TODO: add support
10415-
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
10416-
return false;
10481+
switch (op->type) {
10482+
case GGML_TYPE_F32:
10483+
case GGML_TYPE_F16:
10484+
case GGML_TYPE_BF16:
10485+
case GGML_TYPE_Q4_0:
10486+
case GGML_TYPE_Q4_1:
10487+
case GGML_TYPE_Q5_0:
10488+
case GGML_TYPE_Q5_1:
10489+
case GGML_TYPE_Q8_0:
10490+
case GGML_TYPE_IQ4_NL:
10491+
return true;
10492+
default:
10493+
return false;
10494+
}
1041710495
} break;
1041810496
case GGML_OP_CONT:
1041910497
case GGML_OP_CPY:
@@ -11028,6 +11106,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1102811106
} else {
1102911107
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
1103011108
}
11109+
} else if (tensor->op == GGML_OP_SET_ROWS) {
11110+
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
1103111111
} else if (tensor->op == GGML_OP_CONT) {
1103211112
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1103311113
} else if (tensor->op == GGML_OP_RESHAPE) {

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
66
#endif // RTE16
77

88
#include "types.comp"
9-
#include "generic_unary_head.comp"
109

11-
#if defined(DATA_A_IQ4_NL)
12-
// 16 invocations needed for init_iq4nl_shmem
13-
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
14-
#else
15-
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
16-
#endif
10+
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
1711

1812
layout (binding = 0) readonly buffer S {float data_s[];};
13+
14+
#if defined(SET_ROWS)
15+
#include "generic_binary_head.comp"
16+
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
17+
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
18+
#else
19+
#include "generic_unary_head.comp"
1920
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
21+
#endif
2022

2123
#if defined(DATA_A_Q4_0)
2224
void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +223,56 @@ void quantize(uint dst_idx, uint src_idx)
221223
}
222224
#endif
223225

226+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
227+
// F32/F16 destination variant: no block packing is done. The single float at
// data_s[src_idx] is converted to A_TYPE and stored at data_q[dst_idx].
void quantize(uint dst_idx, uint src_idx)
228+
{
229+
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
230+
}
231+
#endif
232+
233+
#if defined(DATA_A_BF16)
234+
// BF16 destination variant: the float at data_s[src_idx] is passed through
// fp32_to_bf16 (defined outside this view) and the result is stored as A_TYPE
// at data_q[dst_idx].
void quantize(uint dst_idx, uint src_idx)
235+
{
236+
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
237+
}
238+
#endif
239+
240+
#if defined(SET_ROWS)
241+
224242
// SET_ROWS entry point. Each invocation handles one group of QUANT_K
// consecutive src0 elements: it looks up the destination row in data_i and
// calls quantize() to write that group into dst at the looked-up row.
void main() {
225243
#ifdef NEEDS_INIT_IQ_SHMEM
226244
init_iq_shmem(gl_WorkGroupSize);
227-
if (gl_LocalInvocationIndex.x != 0) {
245+
#endif
246+
247+
// Flatten the workgroup ID and local invocation ID into a linear group number
// (32 invocations per workgroup in x), then scale by QUANT_K to get the index
// of the first src0 element handled by this invocation.
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
248+
249+
// Invocations past the end of the data have nothing to do.
if (idx >= p.ne) {
228250
return;
229251
}
252+
253+
// Split the linear index into four per-dimension indices.
// NOTE(review): get_indices is defined outside this view - presumably in the
// included generic head; confirm it decomposes against the src0 extents.
uint i00, i01, i02, i03;
254+
get_indices(idx, i00, i01, i02, i03);
255+
256+
// src1 coordinates: dimensions 2 and 3 of the source wrap modulo ne11/ne12,
// while the source row index i01 is used directly.
uint i12 = i03 % p.ne12;
257+
uint i11 = i02 % p.ne11;
258+
uint i10 = i01;
259+
260+
// Destination row index; only the .x component of the entry is used.
// NOTE(review): presumably the low 32 bits of a 64-bit index - confirm.
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
261+
262+
// Source offset uses the original coordinates; the destination offset replaces
// the row with i1 and divides the column by QUANT_K.
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
263+
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
264+
265+
quantize(dst_idx, src0_idx);
266+
}
267+
268+
#else
269+
270+
void main() {
271+
#ifdef NEEDS_INIT_IQ_SHMEM
272+
init_iq_shmem(gl_WorkGroupSize);
230273
#endif
231274

232-
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
275+
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
233276

234277
if (idx >= p.ne) {
235278
return;
@@ -240,3 +283,5 @@ void main() {
240283

241284
quantize(dst_idx, src_idx);
242285
}
286+
287+
#endif

0 commit comments

Comments
 (0)