@@ -249,6 +249,7 @@ struct vk_device_struct {
249249 vk_pipeline pipeline_relu_f32;
250250 vk_pipeline pipeline_leaky_relu_f32;
251251 vk_pipeline pipeline_tanh_f32;
252+ vk_pipeline pipeline_sigmoid_f32;
252253 vk_pipeline pipeline_diag_mask_inf_f32;
253254 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
254255 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2189,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21892190 ggml_vk_create_pipeline (device, device->pipeline_relu_f32 , " relu_f32" , relu_f32_len, relu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21902191 ggml_vk_create_pipeline (device, device->pipeline_leaky_relu_f32 , " leaky_relu_f32" , leaky_relu_f32_len, leaky_relu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21912192 ggml_vk_create_pipeline (device, device->pipeline_tanh_f32 , " tanh_f32" , tanh_f32_len, tanh_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
2193+ ggml_vk_create_pipeline (device, device->pipeline_sigmoid_f32 , " sigmoid_f32" , sigmoid_f32_len, sigmoid_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21922194
21932195 ggml_vk_create_pipeline (device, device->pipeline_diag_mask_inf_f32 , " diag_mask_inf_f32" , diag_mask_inf_f32_len, diag_mask_inf_f32_data, " main" , 2 , sizeof (vk_op_diag_mask_push_constants), {1 , 512 , 1 }, {}, 1 , true );
21942196
@@ -5342,6 +5344,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53425344 return ctx->device ->pipeline_tanh_f32 ;
53435345 }
53445346 break ;
5347+ case GGML_UNARY_OP_SIGMOID:
5348+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5349+ return ctx->device ->pipeline_sigmoid_f32 ;
5350+ }
5351+ break ;
53455352 default :
53465353 break ;
53475354 }
@@ -7335,6 +7342,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73357342 case GGML_UNARY_OP_GELU_QUICK:
73367343 case GGML_UNARY_OP_RELU:
73377344 case GGML_UNARY_OP_TANH:
7345+ case GGML_UNARY_OP_SIGMOID:
73387346 break ;
73397347 default :
73407348 return false ;
@@ -7551,6 +7559,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
75517559 case GGML_UNARY_OP_GELU_QUICK:
75527560 case GGML_UNARY_OP_RELU:
75537561 case GGML_UNARY_OP_TANH:
7562+ case GGML_UNARY_OP_SIGMOID:
75547563 ggml_vk_unary (ctx, compute_ctx, src0, node, dryrun);
75557564 break ;
75567565 default :
@@ -7738,6 +7747,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
77387747 case GGML_UNARY_OP_GELU_QUICK:
77397748 case GGML_UNARY_OP_RELU:
77407749 case GGML_UNARY_OP_TANH:
7750+ case GGML_UNARY_OP_SIGMOID:
77417751 buf = tensor->buffer ;
77427752 break ;
77437753 default :
@@ -8439,6 +8449,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84398449 case GGML_UNARY_OP_SILU:
84408450 case GGML_UNARY_OP_RELU:
84418451 case GGML_UNARY_OP_TANH:
8452+ case GGML_UNARY_OP_SIGMOID:
84428453 return ggml_is_contiguous (op->src [0 ]);
84438454 default :
84448455 return false ;
@@ -9105,6 +9116,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
91059116 case GGML_UNARY_OP_TANH:
91069117 tensor_clone = ggml_tanh (ggml_ctx, src_clone[0 ]);
91079118 break ;
9119+ case GGML_UNARY_OP_SIGMOID:
9120+ tensor_clone = ggml_sigmoid (ggml_ctx, src_clone[0 ]);
9121+ break ;
91089122 default :
91099123 std::cerr << " Missing vk_check_results OP: " << ggml_op_name (tensor->op ) << std::endl;
91109124 GGML_ABORT (" fatal error" );
0 commit comments