Skip to content

Commit d3f049b

Browse files
committed
fix kernel init
1 parent e585cba commit d3f049b

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -428,6 +428,7 @@ struct ggml_backend_opencl_context {
428428
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
429429
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
430430
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
431+
std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
431432
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
432433
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
433434
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -1311,7 +1312,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13111312
cl_kernel k_f16, k_f16_q1;
13121313
CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
13131314
CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
1314-
GGML_ASSERT(k_f16 != NULL && k_f16_q1 != NULL);
13151315
backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
13161316
backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
13171317
CL_CHECK(clReleaseProgram(prog_f16));
@@ -1320,7 +1320,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13201320
cl_kernel k_f32, k_f32_q1;
13211321
CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
13221322
CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
1323-
GGML_ASSERT(k_f32 != NULL && k_f32_q1 != NULL);
13241323
backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
13251324
backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
13261325
CL_CHECK(clReleaseProgram(prog_f32));
@@ -1329,12 +1328,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13291328
cl_kernel k_f32_f16, k_f32_f16_q1;
13301329
CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
13311330
CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
1332-
GGML_ASSERT(k_f32_f16 != NULL && k_f32_f16_q1 != NULL);
13331331
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
13341332
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
13351333
CL_CHECK(clReleaseProgram(prog_f32_f16));
13361334

13371335
backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
1336+
backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
13381337
}
13391338
GGML_LOG_CONT(".");
13401339
}

0 commit comments

Comments
 (0)