@@ -428,6 +428,7 @@ struct ggml_backend_opencl_context {
428428 std::map<std::pair<int , int >, cl_kernel> kernels_flash_attn_f32_f16;
429429 std::map<std::pair<int , int >, cl_kernel> kernels_flash_attn_f32_f16_q1;
430430 std::map<std::pair<int , int >, int > kernels_flash_attn_bm;
431+ std::map<std::pair<int , int >, int > kernels_flash_attn_bn;
431432 cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
432433 cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
433434 cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -1311,7 +1312,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13111312 cl_kernel k_f16, k_f16_q1;
13121313 CL_CHECK ((k_f16 = clCreateKernel (prog_f16, " flash_attn_f16" , &err), err));
13131314 CL_CHECK ((k_f16_q1 = clCreateKernel (prog_f16, " flash_attn_f16_q1" , &err), err));
1314- GGML_ASSERT (k_f16 != NULL && k_f16_q1 != NULL );
13151315 backend_ctx->kernels_flash_attn_f16 [{dk, dv}] = k_f16;
13161316 backend_ctx->kernels_flash_attn_f16_q1 [{dk, dv}] = k_f16_q1;
13171317 CL_CHECK (clReleaseProgram (prog_f16));
@@ -1320,7 +1320,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13201320 cl_kernel k_f32, k_f32_q1;
13211321 CL_CHECK ((k_f32 = clCreateKernel (prog_f32, " flash_attn_f32" , &err), err));
13221322 CL_CHECK ((k_f32_q1 = clCreateKernel (prog_f32, " flash_attn_f32_q1" , &err), err));
1323- GGML_ASSERT (k_f32 != NULL && k_f32_q1 != NULL );
13241323 backend_ctx->kernels_flash_attn_f32 [{dk, dv}] = k_f32;
13251324 backend_ctx->kernels_flash_attn_f32_q1 [{dk, dv}] = k_f32_q1;
13261325 CL_CHECK (clReleaseProgram (prog_f32));
@@ -1329,12 +1328,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13291328 cl_kernel k_f32_f16, k_f32_f16_q1;
13301329 CL_CHECK ((k_f32_f16 = clCreateKernel (prog_f32_f16, " flash_attn_f32_f16" , &err), err));
13311330 CL_CHECK ((k_f32_f16_q1 = clCreateKernel (prog_f32_f16, " flash_attn_f32_f16_q1" , &err), err));
1332- GGML_ASSERT (k_f32_f16 != NULL && k_f32_f16_q1 != NULL );
13331331 backend_ctx->kernels_flash_attn_f32_f16 [{dk, dv}] = k_f32_f16;
13341332 backend_ctx->kernels_flash_attn_f32_f16_q1 [{dk, dv}] = k_f32_f16_q1;
13351333 CL_CHECK (clReleaseProgram (prog_f32_f16));
13361334
13371335 backend_ctx->kernels_flash_attn_bm [{dk, dv}] = bm;
1336+ backend_ctx->kernels_flash_attn_bn [{dk, dv}] = bn;
13381337 }
13391338 GGML_LOG_CONT (" ." );
13401339 }
0 commit comments