@@ -143,6 +143,7 @@ struct ggml_backend_opencl_context {
143143 cl_kernel kernel_rms_norm;
144144 cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
145145 cl_kernel kernel_soft_max, kernel_soft_max_4;
146+ cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
146147 cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
147148 cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
148149 cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
@@ -614,6 +615,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
614615 CL_CHECK ((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel (backend_ctx->program , " kernel_diag_mask_inf_8" , &err), err));
615616 CL_CHECK ((backend_ctx->kernel_soft_max = clCreateKernel (backend_ctx->program , " kernel_soft_max" , &err), err));
616617 CL_CHECK ((backend_ctx->kernel_soft_max_4 = clCreateKernel (backend_ctx->program , " kernel_soft_max_4" , &err), err));
618+ CL_CHECK ((backend_ctx->kernel_soft_max_f16 = clCreateKernel (backend_ctx->program , " kernel_soft_max_f16" , &err), err));
619+ CL_CHECK ((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel (backend_ctx->program , " kernel_soft_max_4_f16" , &err), err));
617620 CL_CHECK ((backend_ctx->kernel_rope_norm_f32 = clCreateKernel (backend_ctx->program , " kernel_rope_norm_f32" , &err), err));
618621 CL_CHECK ((backend_ctx->kernel_rope_norm_f16 = clCreateKernel (backend_ctx->program , " kernel_rope_norm_f16" , &err), err));
619622 CL_CHECK ((backend_ctx->kernel_rope_neox_f32 = clCreateKernel (backend_ctx->program , " kernel_rope_neox_f32" , &err), err));
@@ -1044,8 +1047,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
10441047 return true ;
10451048 case GGML_OP_DIAG_MASK_INF:
10461049 return op->ne [3 ] == 1 ;
1047- case GGML_OP_ROPE:
1050+ case GGML_OP_ROPE: {
1051+ const int mode = ((const int32_t *) op->op_params )[2 ];
1052+ if (mode & GGML_ROPE_TYPE_MROPE) {
1053+ return false ;
1054+ }
1055+ if (mode & GGML_ROPE_TYPE_VISION) {
1056+ return false ;
1057+ }
10481058 return true ;
1059+ }
10491060 default :
10501061 return false ;
10511062 }
@@ -3666,6 +3677,8 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
36663677 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
36673678 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
36683679
3680+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
3681+
36693682 // Local size must be wave size. Each workgroup is a wave, working on a row,
36703683 // where a row corresponds to leading dimension.
36713684 int nth = MIN (32 , ne00);
@@ -3683,9 +3696,17 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
36833696 cl_kernel kernel;
36843697
36853698 if (ne00%4 == 0 ) {
3686- kernel = backend_ctx->kernel_soft_max_4 ;
3699+ if (use_f16) {
3700+ kernel = backend_ctx->kernel_soft_max_4_f16 ;
3701+ } else {
3702+ kernel = backend_ctx->kernel_soft_max_4 ;
3703+ }
36873704 } else {
3688- kernel = backend_ctx->kernel_soft_max ;
3705+ if (use_f16) {
3706+ kernel = backend_ctx->kernel_soft_max_f16 ;
3707+ } else {
3708+ kernel = backend_ctx->kernel_soft_max ;
3709+ }
36893710 }
36903711
36913712 CL_CHECK (clSetKernelArg (kernel, 0 , sizeof (cl_mem), &extra0->data_device ));
@@ -3766,7 +3787,8 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
37663787 const int nb2 = dst ? dst->nb [2 ] : 0 ;
37673788 const int nb3 = dst ? dst->nb [3 ] : 0 ;
37683789
3769- GGML_ASSERT (ne10 == ne02);
3790+ GGML_ASSERT (ne10 % ne02 == 0 );
3791+ GGML_ASSERT (ne10 >= ne02);
37703792
37713793 int nth = MIN (64 , ne00);
37723794
0 commit comments