Skip to content

Commit 2749ef5

Browse files
committed
vulkan: Additional type support for unary, binary, and copy
Support f16->f32 copy. Support f16->f16 and f32->f32 unary ops. Support all combinations of f16/f32 for src0/src1/dst for add/sub/mul/div.
1 parent d24d592 commit 2749ef5

File tree

5 files changed

+161
-87
lines changed

5 files changed

+161
-87
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 115 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,17 @@ struct vk_device_struct {
340340
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
341341
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
342342
vk_pipeline pipeline_acc_f32;
343-
vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
344-
vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
345-
vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
346-
vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
347-
vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
343+
344+
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
345+
vk_pipeline pipeline_add[2][2][2];
346+
vk_pipeline pipeline_add_norepeat[2][2][2];
347+
vk_pipeline pipeline_sub[2][2][2];
348+
vk_pipeline pipeline_sub_norepeat[2][2][2];
349+
vk_pipeline pipeline_mul[2][2][2];
350+
vk_pipeline pipeline_mul_norepeat[2][2][2];
351+
vk_pipeline pipeline_div[2][2][2];
352+
vk_pipeline pipeline_div_norepeat[2][2][2];
353+
348354
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
349355
vk_pipeline pipeline_upscale_f32;
350356
vk_pipeline pipeline_scale_f32;
@@ -354,23 +360,26 @@ struct vk_device_struct {
354360
vk_pipeline pipeline_clamp_f32;
355361
vk_pipeline pipeline_pad_f32;
356362
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
357-
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
358-
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
363+
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
364+
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
359365
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
360366
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
361367
vk_pipeline pipeline_norm_f32;
362368
vk_pipeline pipeline_group_norm_f32;
363369
vk_pipeline pipeline_rms_norm_f32;
364370
vk_pipeline pipeline_rms_norm_back_f32;
365371
vk_pipeline pipeline_l2_norm_f32;
366-
vk_pipeline pipeline_gelu_f32;
367-
vk_pipeline pipeline_gelu_quick_f32;
368-
vk_pipeline pipeline_silu_f32;
369-
vk_pipeline pipeline_silu_back_f32;
370-
vk_pipeline pipeline_relu_f32;
372+
373+
// [src/dst 0=fp32,1=fp16]
374+
vk_pipeline pipeline_gelu[2];
375+
vk_pipeline pipeline_gelu_quick[2];
376+
vk_pipeline pipeline_silu[2];
377+
vk_pipeline pipeline_relu[2];
378+
vk_pipeline pipeline_tanh[2];
379+
vk_pipeline pipeline_sigmoid[2];
380+
371381
vk_pipeline pipeline_leaky_relu_f32;
372-
vk_pipeline pipeline_tanh_f32;
373-
vk_pipeline pipeline_sigmoid_f32;
382+
vk_pipeline pipeline_silu_back_f32;
374383
vk_pipeline pipeline_diag_mask_inf_f32;
375384
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
376385
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2488,11 +2497,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
24882497
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24892498
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24902499
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2500+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24912501
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24922502

24932503
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24942504
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24952505
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2506+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24962507
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
24972508

24982509
if (device->float_controls_rte_fp16) {
@@ -2518,19 +2529,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
25182529
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
25192530
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
25202531

2521-
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2522-
ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2523-
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2524-
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2532+
auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
2533+
std::string s;
2534+
s += std::string(src0_f16 ? "_f16" : "_f32");
2535+
s += std::string(src1_f16 ? "_f16" : "_f32");
2536+
s += std::string(dst_f16 ? "_f16" : "_f32");
2537+
return s;
2538+
};
25252539

2526-
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2540+
#define CREATE_BINARY(name, namemod, spec) \
2541+
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2542+
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2543+
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
2544+
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2545+
2546+
CREATE_BINARY(add, , {0})
2547+
CREATE_BINARY(add, _norepeat, {1})
2548+
CREATE_BINARY(sub, , {0})
2549+
CREATE_BINARY(sub, _norepeat, {1})
2550+
CREATE_BINARY(mul, , {0})
2551+
CREATE_BINARY(mul, _norepeat, {1})
2552+
CREATE_BINARY(div, , {0})
2553+
CREATE_BINARY(div, _norepeat, {1})
2554+
#undef CREATE_BINARY
25272555

2528-
ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2529-
ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2530-
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2531-
ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2532-
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2533-
ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2556+
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
25342557

25352558
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
25362559
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -2551,14 +2574,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
25512574
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
25522575
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
25532576

2554-
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2555-
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2556-
ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2557-
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2558-
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2577+
#define CREATE_UNARY(name) \
2578+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
2579+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2580+
2581+
CREATE_UNARY(gelu)
2582+
CREATE_UNARY(gelu_quick)
2583+
CREATE_UNARY(silu)
2584+
CREATE_UNARY(relu)
2585+
CREATE_UNARY(tanh)
2586+
CREATE_UNARY(sigmoid)
2587+
#undef CREATE_UNARY
2588+
25592589
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2560-
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2561-
ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2590+
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
25622591

25632592
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
25642593

@@ -4481,6 +4510,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
44814510
return ctx->device->pipeline_cpy_f16_f16;
44824511
}
44834512
}
4513+
if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
4514+
if (contig) {
4515+
return ctx->device->pipeline_contig_cpy_f16_f32;
4516+
} else {
4517+
return ctx->device->pipeline_cpy_f16_f32;
4518+
}
4519+
}
44844520
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
44854521
if (contig) {
44864522
return ctx->device->pipeline_contig_cpy_f32_bf16;
@@ -5871,26 +5907,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
58715907
}
58725908
return nullptr;
58735909
case GGML_OP_ADD:
5874-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5875-
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
5910+
case GGML_OP_SUB:
5911+
case GGML_OP_MUL:
5912+
case GGML_OP_DIV:
5913+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
5914+
(src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
5915+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
5916+
return nullptr;
58765917
}
5877-
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
5878-
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
5918+
switch (op) {
5919+
case GGML_OP_ADD:
5920+
{
5921+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
5922+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
58795923
}
5880-
return nullptr;
5881-
case GGML_OP_SUB:
5882-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5883-
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
5924+
case GGML_OP_SUB:
5925+
{
5926+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
5927+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
58845928
}
5885-
return nullptr;
5886-
case GGML_OP_MUL:
5887-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5888-
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
5929+
case GGML_OP_MUL:
5930+
{
5931+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
5932+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
58895933
}
5890-
return nullptr;
5891-
case GGML_OP_DIV:
5892-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5893-
return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
5934+
case GGML_OP_DIV:
5935+
{
5936+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
5937+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5938+
}
5939+
default:
5940+
break;
58945941
}
58955942
return nullptr;
58965943
case GGML_OP_CONCAT:
@@ -5984,37 +6031,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
59846031
}
59856032
return nullptr;
59866033
case GGML_OP_UNARY:
6034+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6035+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6036+
(src0->type != dst->type)) {
6037+
return nullptr;
6038+
}
6039+
59876040
switch (ggml_get_unary_op(dst)) {
59886041
case GGML_UNARY_OP_SILU:
5989-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5990-
return ctx->device->pipeline_silu_f32;
5991-
}
5992-
break;
6042+
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
59936043
case GGML_UNARY_OP_GELU:
5994-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5995-
return ctx->device->pipeline_gelu_f32;
5996-
}
5997-
break;
6044+
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
59986045
case GGML_UNARY_OP_GELU_QUICK:
5999-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6000-
return ctx->device->pipeline_gelu_quick_f32;
6001-
}
6002-
break;
6046+
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
60036047
case GGML_UNARY_OP_RELU:
6004-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6005-
return ctx->device->pipeline_relu_f32;
6006-
}
6007-
break;
6048+
return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
60086049
case GGML_UNARY_OP_TANH:
6009-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6010-
return ctx->device->pipeline_tanh_f32;
6011-
}
6012-
break;
6050+
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
60136051
case GGML_UNARY_OP_SIGMOID:
6014-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6015-
return ctx->device->pipeline_sigmoid_f32;
6016-
}
6017-
break;
6052+
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
60186053
default:
60196054
break;
60206055
}
@@ -9358,7 +9393,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
93589393
case GGML_UNARY_OP_RELU:
93599394
case GGML_UNARY_OP_TANH:
93609395
case GGML_UNARY_OP_SIGMOID:
9361-
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
9396+
return ggml_is_contiguous(op->src[0]) &&
9397+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
9398+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
9399+
(op->src[0]->type == op->type);
93629400
default:
93639401
return false;
93649402
}
@@ -9538,6 +9576,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
95389576
}
95399577
if (src1_type == GGML_TYPE_F32) {
95409578
switch (src0_type) {
9579+
case GGML_TYPE_F16:
95419580
case GGML_TYPE_Q4_0:
95429581
case GGML_TYPE_Q4_1:
95439582
case GGML_TYPE_Q5_0:
@@ -9576,6 +9615,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
95769615
case GGML_OP_SUB:
95779616
case GGML_OP_MUL:
95789617
case GGML_OP_DIV:
9618+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
9619+
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
9620+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
95799621
case GGML_OP_SILU_BACK:
95809622
case GGML_OP_RMS_NORM_BACK:
95819623
case GGML_OP_SQR:

ggml/src/ggml-vulkan/vulkan-shaders/relu.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ void main() {
1717
return;
1818
}
1919

20-
data_d[i] = max(float(data_a[i]), 0);
20+
data_d[i] = D_TYPE(max(float(data_a[i]), 0));
2121
}

ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ void main() {
1616
if (i >= p.KX) {
1717
return;
1818
}
19-
data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
19+
data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i]))));
2020
}

ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@ void main() {
1616
if (i >= p.KX) {
1717
return;
1818
}
19-
data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.));
19+
data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.));
2020
}

0 commit comments

Comments
 (0)