@@ -490,6 +490,7 @@ struct vk_device_struct {
490
490
vk_pipeline pipeline_l2_norm_f32;
491
491
492
492
// Array index selects the src/dst type: [0] = fp32, [1] = fp16
493
+ vk_pipeline pipeline_exp[2];
493
494
vk_pipeline pipeline_gelu[2];
494
495
vk_pipeline pipeline_gelu_erf[2];
495
496
vk_pipeline pipeline_gelu_quick[2];
@@ -3066,6 +3067,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3066
3067
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3067
3068
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
3068
3069
3070
+ CREATE_UNARY(exp)
3069
3071
CREATE_UNARY(gelu)
3070
3072
CREATE_UNARY(gelu_erf)
3071
3073
CREATE_UNARY(gelu_quick)
@@ -7133,6 +7135,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
7133
7135
}
7134
7136
7135
7137
switch (ggml_get_unary_op(dst)) {
7138
+ case GGML_UNARY_OP_EXP:
7139
+ return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
7136
7140
case GGML_UNARY_OP_SILU:
7137
7141
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
7138
7142
case GGML_UNARY_OP_GELU:
@@ -9738,6 +9742,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9738
9742
return false;
9739
9743
case GGML_OP_UNARY:
9740
9744
switch (ggml_get_unary_op(node)) {
9745
+ case GGML_UNARY_OP_EXP:
9741
9746
case GGML_UNARY_OP_SILU:
9742
9747
case GGML_UNARY_OP_GELU:
9743
9748
case GGML_UNARY_OP_GELU_ERF:
@@ -10015,6 +10020,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
10015
10020
break;
10016
10021
case GGML_OP_UNARY:
10017
10022
switch (ggml_get_unary_op(node)) {
10023
+ case GGML_UNARY_OP_EXP:
10018
10024
case GGML_UNARY_OP_SILU:
10019
10025
case GGML_UNARY_OP_GELU:
10020
10026
case GGML_UNARY_OP_GELU_ERF:
@@ -10251,6 +10257,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
10251
10257
break;
10252
10258
case GGML_OP_UNARY:
10253
10259
switch (ggml_get_unary_op(tensor)) {
10260
+ case GGML_UNARY_OP_EXP:
10254
10261
case GGML_UNARY_OP_SILU:
10255
10262
case GGML_UNARY_OP_GELU:
10256
10263
case GGML_UNARY_OP_GELU_ERF:
@@ -11166,6 +11173,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11166
11173
switch (op->op) {
11167
11174
case GGML_OP_UNARY:
11168
11175
switch (ggml_get_unary_op(op)) {
11176
+ case GGML_UNARY_OP_EXP:
11169
11177
case GGML_UNARY_OP_GELU:
11170
11178
case GGML_UNARY_OP_GELU_ERF:
11171
11179
case GGML_UNARY_OP_GELU_QUICK:
@@ -11965,6 +11973,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
11965
11973
}
11966
11974
} else if (tensor->op == GGML_OP_UNARY) {
11967
11975
switch (ggml_get_unary_op(tensor)) {
11976
+ case GGML_UNARY_OP_EXP:
11977
+ tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
11978
+ break;
11968
11979
case GGML_UNARY_OP_SILU:
11969
11980
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
11970
11981
break;
0 commit comments