@@ -143,6 +143,7 @@ struct ggml_backend_opencl_context {
143143 cl_kernel kernel_rms_norm;
144144 cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
145145 cl_kernel kernel_soft_max, kernel_soft_max_4;
146+ cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
146147 cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
147148 cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
148149 cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
@@ -614,6 +615,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
614615 CL_CHECK ((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel (backend_ctx->program , " kernel_diag_mask_inf_8" , &err), err));
615616 CL_CHECK ((backend_ctx->kernel_soft_max = clCreateKernel (backend_ctx->program , " kernel_soft_max" , &err), err));
616617 CL_CHECK ((backend_ctx->kernel_soft_max_4 = clCreateKernel (backend_ctx->program , " kernel_soft_max_4" , &err), err));
618+ CL_CHECK ((backend_ctx->kernel_soft_max_f16 = clCreateKernel (backend_ctx->program , " kernel_soft_max_f16" , &err), err));
619+ CL_CHECK ((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel (backend_ctx->program , " kernel_soft_max_4_f16" , &err), err));
617620 CL_CHECK ((backend_ctx->kernel_rope_norm_f32 = clCreateKernel (backend_ctx->program , " kernel_rope_norm_f32" , &err), err));
618621 CL_CHECK ((backend_ctx->kernel_rope_norm_f16 = clCreateKernel (backend_ctx->program , " kernel_rope_norm_f16" , &err), err));
619622 CL_CHECK ((backend_ctx->kernel_rope_neox_f32 = clCreateKernel (backend_ctx->program , " kernel_rope_neox_f32" , &err), err));
@@ -3674,6 +3677,8 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
36743677 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
36753678 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
36763679
3680+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
3681+
36773682 // Local size must be wave size. Each workgroup is a wave, working on a row,
36783683 // where a row corresponds to leading dimension.
36793684 int nth = MIN (32 , ne00);
@@ -3691,9 +3696,17 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
36913696 cl_kernel kernel;
36923697
36933698 if (ne00%4 == 0 ) {
3694- kernel = backend_ctx->kernel_soft_max_4 ;
3699+ if (use_f16) {
3700+ kernel = backend_ctx->kernel_soft_max_4_f16 ;
3701+ } else {
3702+ kernel = backend_ctx->kernel_soft_max_4 ;
3703+ }
36953704 } else {
3696- kernel = backend_ctx->kernel_soft_max ;
3705+ if (use_f16) {
3706+ kernel = backend_ctx->kernel_soft_max_f16 ;
3707+ } else {
3708+ kernel = backend_ctx->kernel_soft_max ;
3709+ }
36973710 }
36983711
36993712 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
0 commit comments