@@ -218,6 +218,7 @@ struct vk_device_struct {
218218 vk_pipeline pipeline_tanh_f32;
219219 vk_pipeline pipeline_diag_mask_inf_f32;
220220 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
221+ vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
221222 vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
222223 vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
223224 vk_pipeline pipeline_argsort_f32;
@@ -1498,7 +1499,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
14981499 ggml_vk_create_pipeline (device, device->pipeline_diag_mask_inf_f32 , " diag_mask_inf_f32" , diag_mask_inf_f32_len, diag_mask_inf_f32_data, " main" , 2 , sizeof (vk_op_diag_mask_push_constants), {512 , 1 , 1 }, {}, 1 );
14991500
15001501 ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32 , " soft_max_f32" , soft_max_f32_len, soft_max_f32_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
1502+ ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_wg512 , " soft_max_f32_wg512" , soft_max_f32_len, soft_max_f32_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { 512 }, 1 );
15011503 ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_f16 , " soft_max_f32_f16" , soft_max_f32_f16_len, soft_max_f32_f16_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
1504+ ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_f16_wg512 , " soft_max_f32_f16_wg512" , soft_max_f32_f16_len, soft_max_f32_f16_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { 512 }, 1 );
15021505
15031506 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f32 , " rope_norm_f32" , rope_norm_f32_len, rope_norm_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
15041507 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16" , rope_norm_f16_len, rope_norm_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
@@ -3933,10 +3936,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
39333936 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
39343937
39353938 if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3936- return ctx->device ->pipeline_soft_max_f32 ;
3939+ return src0-> ne [ 0 ] > 1024 ? ctx-> device -> pipeline_soft_max_f32_wg512 : ctx->device ->pipeline_soft_max_f32 ;
39373940 }
39383941 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
3939- return ctx->device ->pipeline_soft_max_f32_f16 ;
3942+ return src0-> ne [ 0 ] > 1024 ? ctx-> device -> pipeline_soft_max_f32_f16_wg512 : ctx->device ->pipeline_soft_max_f32_f16 ;
39403943 }
39413944 return nullptr ;
39423945 case GGML_OP_ROPE:
0 commit comments