@@ -247,6 +247,7 @@ struct vk_device_struct {
247247 vk_pipeline pipeline_relu_f32;
248248 vk_pipeline pipeline_leaky_relu_f32;
249249 vk_pipeline pipeline_tanh_f32;
250+ vk_pipeline pipeline_sigmoid_f32;
250251 vk_pipeline pipeline_diag_mask_inf_f32;
251252 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
252253 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2183,6 +2184,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21832184 ggml_vk_create_pipeline (device, device->pipeline_relu_f32 , " relu_f32" , relu_f32_len, relu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21842185 ggml_vk_create_pipeline (device, device->pipeline_leaky_relu_f32 , " leaky_relu_f32" , leaky_relu_f32_len, leaky_relu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21852186 ggml_vk_create_pipeline (device, device->pipeline_tanh_f32 , " tanh_f32" , tanh_f32_len, tanh_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
2187+ ggml_vk_create_pipeline (device, device->pipeline_sigmoid_f32 , " sigmoid_f32" , sigmoid_f32_len, sigmoid_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21862188
21872189 ggml_vk_create_pipeline (device, device->pipeline_diag_mask_inf_f32 , " diag_mask_inf_f32" , diag_mask_inf_f32_len, diag_mask_inf_f32_data, " main" , 2 , sizeof (vk_op_diag_mask_push_constants), {1 , 512 , 1 }, {}, 1 , true );
21882190
@@ -5325,6 +5327,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53255327 return ctx->device ->pipeline_tanh_f32 ;
53265328 }
53275329 break ;
5330+ case GGML_UNARY_OP_SIGMOID:
5331+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5332+ return ctx->device ->pipeline_sigmoid_f32 ;
5333+ }
5334+ break ;
53285335 default :
53295336 break ;
53305337 }
@@ -7295,6 +7302,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72957302 case GGML_UNARY_OP_GELU_QUICK:
72967303 case GGML_UNARY_OP_RELU:
72977304 case GGML_UNARY_OP_TANH:
7305+ case GGML_UNARY_OP_SIGMOID:
72987306 break ;
72997307 default :
73007308 return false ;
@@ -7495,6 +7503,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
74957503 case GGML_UNARY_OP_GELU_QUICK:
74967504 case GGML_UNARY_OP_RELU:
74977505 case GGML_UNARY_OP_TANH:
7506+ case GGML_UNARY_OP_SIGMOID:
74987507 ggml_vk_unary (ctx, compute_ctx, src0, node, dryrun);
74997508 break ;
75007509 default :
@@ -7670,6 +7679,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
76707679 case GGML_UNARY_OP_GELU_QUICK:
76717680 case GGML_UNARY_OP_RELU:
76727681 case GGML_UNARY_OP_TANH:
7682+ case GGML_UNARY_OP_SIGMOID:
76737683 buf = tensor->buffer ;
76747684 break ;
76757685 default :
@@ -8371,6 +8381,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83718381 case GGML_UNARY_OP_SILU:
83728382 case GGML_UNARY_OP_RELU:
83738383 case GGML_UNARY_OP_TANH:
8384+ case GGML_UNARY_OP_SIGMOID:
83748385 return ggml_is_contiguous (op->src [0 ]);
83758386 default :
83768387 return false ;
@@ -9018,6 +9029,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
90189029 case GGML_UNARY_OP_TANH:
90199030 tensor_clone = ggml_tanh (ggml_ctx, src_clone[0 ]);
90209031 break ;
9032+ case GGML_UNARY_OP_SIGMOID:
9033+ tensor_clone = ggml_sigmoid (ggml_ctx, src_clone[0 ]);
9034+ break ;
90219035 default :
90229036 std::cerr << " Missing vk_check_results OP: " << ggml_op_name (tensor->op ) << std::endl;
90239037 GGML_ABORT (" fatal error" );
0 commit comments