@@ -490,6 +490,7 @@ struct vk_device_struct {
490490 vk_pipeline pipeline_l2_norm_f32;
491491
492492 // [src/dst 0=fp32,1=fp16]
493+ vk_pipeline pipeline_exp[2];
493494 vk_pipeline pipeline_gelu[2];
494495 vk_pipeline pipeline_gelu_erf[2];
495496 vk_pipeline pipeline_gelu_quick[2];
@@ -3062,6 +3063,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30623063 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
30633064 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
30643065
3066+ CREATE_UNARY(exp)
30653067 CREATE_UNARY(gelu)
30663068 CREATE_UNARY(gelu_erf)
30673069 CREATE_UNARY(gelu_quick)
@@ -7097,6 +7099,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
70977099 }
70987100
70997101 switch (ggml_get_unary_op(dst)) {
7102+ case GGML_UNARY_OP_EXP:
7103+ return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
71007104 case GGML_UNARY_OP_SILU:
71017105 return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
71027106 case GGML_UNARY_OP_GELU:
@@ -9702,6 +9706,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97029706 return false;
97039707 case GGML_OP_UNARY:
97049708 switch (ggml_get_unary_op(node)) {
9709+ case GGML_UNARY_OP_EXP:
97059710 case GGML_UNARY_OP_SILU:
97069711 case GGML_UNARY_OP_GELU:
97079712 case GGML_UNARY_OP_GELU_ERF:
@@ -9979,6 +9984,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
99799984 break;
99809985 case GGML_OP_UNARY:
99819986 switch (ggml_get_unary_op(node)) {
9987+ case GGML_UNARY_OP_EXP:
99829988 case GGML_UNARY_OP_SILU:
99839989 case GGML_UNARY_OP_GELU:
99849990 case GGML_UNARY_OP_GELU_ERF:
@@ -10215,6 +10221,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1021510221 break;
1021610222 case GGML_OP_UNARY:
1021710223 switch (ggml_get_unary_op(tensor)) {
10224+ case GGML_UNARY_OP_EXP:
1021810225 case GGML_UNARY_OP_SILU:
1021910226 case GGML_UNARY_OP_GELU:
1022010227 case GGML_UNARY_OP_GELU_ERF:
@@ -11125,6 +11132,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1112511132 switch (op->op) {
1112611133 case GGML_OP_UNARY:
1112711134 switch (ggml_get_unary_op(op)) {
11135+ case GGML_UNARY_OP_EXP:
1112811136 case GGML_UNARY_OP_GELU:
1112911137 case GGML_UNARY_OP_GELU_ERF:
1113011138 case GGML_UNARY_OP_GELU_QUICK:
@@ -11924,6 +11932,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1192411932 }
1192511933 } else if (tensor->op == GGML_OP_UNARY) {
1192611934 switch (ggml_get_unary_op(tensor)) {
11935+ case GGML_UNARY_OP_EXP:
11936+ tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
11937+ break;
1192711938 case GGML_UNARY_OP_SILU:
1192811939 tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
1192911940 break;
0 commit comments