@@ -490,6 +490,7 @@ struct vk_device_struct {
490490 vk_pipeline pipeline_l2_norm_f32;
491491
492492 // [src/dst 0=fp32,1=fp16]
493+ vk_pipeline pipeline_exp[2];
493494 vk_pipeline pipeline_gelu[2];
494495 vk_pipeline pipeline_gelu_erf[2];
495496 vk_pipeline pipeline_gelu_quick[2];
@@ -3066,6 +3067,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
30663067 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
30673068 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
30683069
3070+ CREATE_UNARY(exp)
30693071 CREATE_UNARY(gelu)
30703072 CREATE_UNARY(gelu_erf)
30713073 CREATE_UNARY(gelu_quick)
@@ -7133,6 +7135,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
71337135 }
71347136
71357137 switch (ggml_get_unary_op(dst)) {
7138+ case GGML_UNARY_OP_EXP:
7139+ return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
71367140 case GGML_UNARY_OP_SILU:
71377141 return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
71387142 case GGML_UNARY_OP_GELU:
@@ -9738,6 +9742,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97389742 return false;
97399743 case GGML_OP_UNARY:
97409744 switch (ggml_get_unary_op(node)) {
9745+ case GGML_UNARY_OP_EXP:
97419746 case GGML_UNARY_OP_SILU:
97429747 case GGML_UNARY_OP_GELU:
97439748 case GGML_UNARY_OP_GELU_ERF:
@@ -10015,6 +10020,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1001510020 break;
1001610021 case GGML_OP_UNARY:
1001710022 switch (ggml_get_unary_op(node)) {
10023+ case GGML_UNARY_OP_EXP:
1001810024 case GGML_UNARY_OP_SILU:
1001910025 case GGML_UNARY_OP_GELU:
1002010026 case GGML_UNARY_OP_GELU_ERF:
@@ -10251,6 +10257,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1025110257 break;
1025210258 case GGML_OP_UNARY:
1025310259 switch (ggml_get_unary_op(tensor)) {
10260+ case GGML_UNARY_OP_EXP:
1025410261 case GGML_UNARY_OP_SILU:
1025510262 case GGML_UNARY_OP_GELU:
1025610263 case GGML_UNARY_OP_GELU_ERF:
@@ -11166,6 +11173,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1116611173 switch (op->op) {
1116711174 case GGML_OP_UNARY:
1116811175 switch (ggml_get_unary_op(op)) {
11176+ case GGML_UNARY_OP_EXP:
1116911177 case GGML_UNARY_OP_GELU:
1117011178 case GGML_UNARY_OP_GELU_ERF:
1117111179 case GGML_UNARY_OP_GELU_QUICK:
@@ -11965,6 +11973,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1196511973 }
1196611974 } else if (tensor->op == GGML_OP_UNARY) {
1196711975 switch (ggml_get_unary_op(tensor)) {
11976+ case GGML_UNARY_OP_EXP:
11977+ tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
11978+ break;
1196811979 case GGML_UNARY_OP_SILU:
1196911980 tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
1197011981 break;
0 commit comments