25 | 25 | #include <vector> |
26 | 26 | #include <string> |
27 | 27 | #include <cmath> |
| 28 | +#include <map> |
28 | 29 | #include <memory> |
29 | 30 | #include <charconv> |
30 | 31 | #include <mutex> |
@@ -420,6 +421,13 @@ struct ggml_backend_opencl_context { |
420 | 421 | cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; |
421 | 422 | cl_kernel kernel_soft_max, kernel_soft_max_4; |
422 | 423 | cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; |
| 424 | + std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16; |
| 425 | + std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1; |
| 426 | + std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32; |
| 427 | + std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1; |
| 428 | + std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16; |
| 429 | + std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1; |
| 430 | + std::map<std::pair<int, int>, int> kernels_flash_attn_bm; |
423 | 431 | cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; |
424 | 432 | cl_kernel kernel_set_rows_f32, kernel_set_rows_f16; |
425 | 433 | cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; |
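The flash-attention kernels are cached per (DK, DV) head-size pair — one map for each supported type combination, each with a _q1 variant for single-token (n_q == 1) queries — and kernels_flash_attn_bm records the BLOCK_M tile size the pair was compiled with. A minimal lookup sketch (illustrative only; the real lookups are in ggml_cl_flash_attn below):

    // Hypothetical example: fetch the F16 kernel specialized for 128/128 heads.
    // std::map::at throws std::out_of_range for a pair that was never built,
    // which the GGML_OP_FLASH_ATTN_EXT case in supports_op is meant to rule out.
    cl_kernel kernel  = backend_ctx->kernels_flash_attn_f16.at({128, 128});
    int       block_m = backend_ctx->kernels_flash_attn_bm.at({128, 128});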
@@ -1263,6 +1271,75 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve |
1263 | 1271 | GGML_LOG_CONT("."); |
1264 | 1272 | } |
1265 | 1273 |
| 1274 | + // flash_attn |
| 1275 | + { |
| 1276 | + #ifdef GGML_OPENCL_EMBED_KERNELS |
| 1277 | + const std::string kernel_src_f16 { |
| 1278 | + #include "flash_attn_f16.cl.h" |
| 1279 | + }; |
| 1280 | + const std::string kernel_src_f32 { |
| 1281 | + #include "flash_attn_f32.cl.h" |
| 1282 | + }; |
| 1283 | + const std::string kernel_src_f32_f16 { |
| 1284 | + #include "flash_attn_f32_f16.cl.h" |
| 1285 | + }; |
| 1286 | + #else |
| 1287 | + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); |
| 1288 | + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); |
| 1289 | + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); |
| 1290 | + #endif |
| 1291 | + |
| 1292 | + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { |
| 1293 | + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { |
| 1294 | + { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, |
| 1295 | + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, |
| 1296 | + {192, 192, 16, 16}, {256, 256, 16, 16}, |
| 1297 | + }; |
| 1298 | + |
| 1299 | + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { |
| 1300 | + const int dk = fa_dims[i].dk; |
| 1301 | + const int dv = fa_dims[i].dv; |
| 1302 | + const int bm = fa_dims[i].bm; |
| 1303 | + const int bn = fa_dims[i].bn; |
| 1304 | + std::string OPTS = compile_opts + |
| 1305 | + " -D DK=" + std::to_string(dk) + |
| 1306 | + " -D DV=" + std::to_string(dv) + |
| 1307 | + " -D BLOCK_M=" + std::to_string(bm) + |
| 1308 | + " -D BLOCK_N=" + std::to_string(bn); |
| 1309 | + |
| 1310 | + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); |
| 1311 | + cl_kernel k_f16, k_f16_q1; |
| 1312 | + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); |
| 1313 | + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); |
| 1314 | + GGML_ASSERT(k_f16 != NULL && k_f16_q1 != NULL); |
| 1315 | + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; |
| 1316 | + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; |
| 1317 | + CL_CHECK(clReleaseProgram(prog_f16)); |
| 1318 | + |
| 1319 | + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); |
| 1320 | + cl_kernel k_f32, k_f32_q1; |
| 1321 | + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); |
| 1322 | + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); |
| 1323 | + GGML_ASSERT(k_f32 != NULL && k_f32_q1 != NULL); |
| 1324 | + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; |
| 1325 | + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; |
| 1326 | + CL_CHECK(clReleaseProgram(prog_f32)); |
| 1327 | + |
| 1328 | + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); |
| 1329 | + cl_kernel k_f32_f16, k_f32_f16_q1; |
| 1330 | + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); |
| 1331 | + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); |
| 1332 | + GGML_ASSERT(k_f32_f16 != NULL && k_f32_f16_q1 != NULL); |
| 1333 | + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; |
| 1334 | + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; |
| 1335 | + CL_CHECK(clReleaseProgram(prog_f32_f16)); |
| 1336 | + |
| 1337 | + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; |
| 1338 | + } |
| 1339 | + GGML_LOG_CONT("."); |
| 1340 | + } |
| 1341 | + } |
| 1342 | + |
1266 | 1343 | // argsort |
1267 | 1344 | { |
1268 | 1345 | #ifdef GGML_OPENCL_EMBED_KERNELS |
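For each of the eight (dk, dv) entries in the flash_attn block above, three programs are built (f16, f32, f32_f16), each yielding a block kernel plus a _q1 single-query variant — 48 specialized kernels in total — and each program is released once its kernels are created. As a worked example, the first table entry {64, 64, 64, 64} appends the following to compile_opts (sketch, compile_opts elided):

    -D DK=64 -D DV=64 -D BLOCK_M=64 -D BLOCK_N=64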
@@ -2553,6 +2630,41 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te |
2553 | 2630 | return op->src[0]->type == GGML_TYPE_F32; |
2554 | 2631 | case GGML_OP_SUM_ROWS: |
2555 | 2632 | return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); |
| 2633 | + case GGML_OP_FLASH_ATTN_EXT: |
| 2634 | + { |
| 2635 | + const ggml_tensor * q = op->src[0]; |
| 2636 | + const ggml_tensor * k = op->src[1]; |
| 2637 | + const ggml_tensor * v = op->src[2]; |
| 2638 | + |
| 2639 | + const int dk = q->ne[0]; |
| 2640 | + const int dv = v->ne[0]; |
| 2641 | + |
| 2642 | + const struct { int dk; int dv; } supported_dims[] = { |
| 2643 | + { 64, 64}, { 80, 80}, { 96, 96}, |
| 2644 | + {112, 112}, {128, 128}, {192, 128}, |
| 2645 | + {192, 192}, {256, 256}, |
| 2646 | + }; |
| 2647 | + |
| 2648 | + bool dims_supported = false; |
| 2649 | + for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { |
| 2650 | + if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { |
| 2651 | + dims_supported = true; |
| 2652 | + break; |
| 2653 | + } |
| 2654 | + } |
| 2655 | + if (!dims_supported) { |
| 2656 | + return false; |
| 2657 | + } |
| 2658 | + |
| 2659 | + const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && |
| 2660 | + v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; |
| 2661 | + const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && |
| 2662 | + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; |
| 2663 | + const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && |
| 2664 | + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; |
| 2665 | + |
| 2666 | + return is_f32_f32 || is_f16_f16 || is_f32_f16; |
| 2667 | + } |
2556 | 2668 | default: |
2557 | 2669 | return false; |
2558 | 2670 | } |
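The head-size gate mirrors the load-time table exactly, so the backend only claims ops for which a specialized kernel was actually built; note the one asymmetric pair, DK=192 with DV=128. A standalone sketch of the same check (not part of the patch), with plain ints for illustration:

    // Sketch: the (dk, dv) whitelist as a self-contained helper.
    #include <cstddef>

    static bool fa_dims_supported(int dk, int dv) {
        static const struct { int dk; int dv; } dims[] = {
            { 64,  64}, { 80,  80}, { 96,  96}, {112, 112},
            {128, 128}, {192, 128}, {192, 192}, {256, 256},
        };
        for (size_t i = 0; i < sizeof(dims)/sizeof(dims[0]); ++i) {
            if (dims[i].dk == dk && dims[i].dv == dv) {
                return true;
            }
        }
        return false;
    }
    // fa_dims_supported(192, 128) -> true; fa_dims_supported(48, 48) -> false.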
@@ -5193,6 +5305,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor |
5193 | 5305 | backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); |
5194 | 5306 | } |
5195 | 5307 |
| 5308 | +static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { |
| 5309 | + const ggml_tensor * v = dst->src[2]; |
| 5310 | + const ggml_tensor * mask = dst->src[3]; |
| 5311 | + GGML_ASSERT(q->extra); |
| 5312 | + GGML_ASSERT(k->extra); |
| 5313 | + GGML_ASSERT(v->extra); |
| 5314 | + GGML_ASSERT(dst->extra); |
| 5315 | + if (mask) { |
| 5316 | + GGML_ASSERT(mask->extra); |
| 5317 | + } |
| 5318 | + |
| 5319 | + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; |
| 5320 | + |
| 5321 | + const int n_q = q->ne[1]; |
| 5322 | + const int n_kv = k->ne[1]; |
| 5323 | + const int d_head_q = q->ne[0]; |
| 5324 | + const int d_head_v = v->ne[0]; |
| 5325 | + const int n_head = q->ne[2]; |
| 5326 | + const int n_head_kv = k->ne[2]; |
| 5327 | + const int n_batch = q->ne[3]; |
| 5328 | + |
| 5329 | + cl_kernel kernel = NULL; |
| 5330 | + |
| 5331 | + const bool is_f16 = q->type == GGML_TYPE_F16; |
| 5332 | + const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16; |
| 5333 | + const std::pair<int, int> dk_dv = {d_head_q, d_head_v}; |
| 5334 | + |
| 5335 | + if (n_q == 1) { |
| 5336 | + if (is_mixed) { |
| 5337 | + kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv); |
| 5338 | + } else if (is_f16) { |
| 5339 | + kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv); |
| 5340 | + } else { |
| 5341 | + kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv); |
| 5342 | + } |
| 5343 | + } else { |
| 5344 | + if (is_mixed) { |
| 5345 | + kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv); |
| 5346 | + } else if (is_f16) { |
| 5347 | + kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv); |
| 5348 | + } else { |
| 5349 | + kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv); |
| 5350 | + } |
| 5351 | + } |
| 5352 | + GGML_ASSERT(kernel != NULL); |
| 5353 | + |
| 5354 | + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra; |
| 5355 | + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra; |
| 5356 | + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; |
| 5357 | + ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; |
| 5358 | + ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; |
| 5359 | + |
| 5360 | + cl_ulong offset_q = extra_q->offset + q->view_offs; |
| 5361 | + cl_ulong offset_k = extra_k->offset + k->view_offs; |
| 5362 | + cl_ulong offset_v = extra_v->offset + v->view_offs; |
| 5363 | + cl_ulong offset_o = extra_o->offset + dst->view_offs; |
| 5364 | + cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; |
| 5365 | + cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; |
| 5366 | + |
| 5367 | + const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; |
| 5368 | + const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; |
| 5369 | + const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3]; |
| 5370 | + const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3]; |
| 5371 | + const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0; |
| 5372 | + const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0; |
| 5373 | + const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0; |
| 5374 | + const int mask_ne2 = mask ? mask->ne[2] : 0; |
| 5375 | + const int mask_ne3 = mask ? mask->ne[3] : 0; |
| 5376 | + |
| 5377 | + float scale, max_bias, logit_softcap; |
| 5378 | + const float * params = (const float *)dst->op_params; |
| 5379 | + scale = params[0]; |
| 5380 | + max_bias = params[1]; |
| 5381 | + logit_softcap = params[2]; |
| 5382 | + |
| 5383 | + const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv); |
| 5384 | + |
| 5385 | + const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0; |
| 5386 | + const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f; |
| 5387 | + const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f); |
| 5388 | + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f); |
| 5389 | + |
| 5390 | + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device)); |
| 5391 | + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); |
| 5392 | + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device)); |
| 5393 | + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); |
| 5394 | + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device)); |
| 5395 | + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); |
| 5396 | + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device)); |
| 5397 | + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); |
| 5398 | + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale)); |
| 5399 | + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q)); |
| 5400 | + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv)); |
| 5401 | + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal)); |
| 5402 | + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head)); |
| 5403 | + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); |
| 5404 | + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); |
| 5405 | + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); |
| 5406 | + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); |
| 5407 | + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias)); |
| 5408 | + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0)); |
| 5409 | + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1)); |
| 5410 | + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val)); |
| 5411 | + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap)); |
| 5412 | + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv)); |
| 5413 | + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer)); |
| 5414 | + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask)); |
| 5415 | + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1)); |
| 5416 | + CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2)); |
| 5417 | + CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); |
| 5418 | + CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2)); |
| 5419 | + CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); |
| 5420 | + |
| 5421 | + if (n_q == 1) { |
| 5422 | + const size_t wg_size = 64; |
| 5423 | + size_t local_work_size[] = { wg_size, 1 }; |
| 5424 | + size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) }; |
| 5425 | + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); |
| 5426 | + } else { |
| 5427 | + const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv); |
| 5428 | + const size_t wg_size = block_m; |
| 5429 | + size_t local_work_size[] = { wg_size, 1 }; |
| 5430 | + size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) }; |
| 5431 | + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); |
| 5432 | + } |
| 5433 | +} |
| 5434 | + |
5196 | 5435 | static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { |
5197 | 5436 | ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; |
5198 | 5437 |
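The host-side launch math in ggml_cl_flash_attn is compact enough that a worked example helps. A standalone sketch (not part of the patch), assuming n_q = 512 with DK = DV = 128 (so BLOCK_M = 32 from the load-time table), n_head = 20, n_batch = 1, and max_bias = 8.0f; when n_q == 1 the *_q1 kernel would instead run one fixed 64-lane work-group per (head, batch) pair:

    // Sketch: the same work-size and ALiBi slope math with illustrative values.
    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n_q = 512, n_head = 20, n_batch = 1; // hypothetical shapes
        const int   block_m  = 32;                       // BLOCK_M for DK = DV = 128
        const float max_bias = 8.0f;                     // hypothetical ALiBi bias

        // One block_m-thread work-group per BLOCK_M rows of Q, times head*batch pairs.
        size_t global[2] = {
            (size_t)((n_q + block_m - 1) / block_m) * (size_t)block_m, // 512
            (size_t)(n_head * n_batch),                                // 20
        };

        // ALiBi slope bases: n_head_log2 is the largest power of two <= n_head.
        const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head)); // 16
        const float m0 = powf(2.0f, -max_bias          / n_head_log2);      // 2^-0.5  ~ 0.707
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);      // 2^-0.25 ~ 0.841
        printf("global=(%zu,%zu) n_head_log2=%d m0=%.3f m1=%.3f\n",
               global[0], global[1], n_head_log2, m0, m1);
        return 0;
    }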
@@ -7239,6 +7478,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor |
7239 | 7478 | } |
7240 | 7479 | func = ggml_cl_sum_rows; |
7241 | 7480 | break; |
| 7481 | + case GGML_OP_FLASH_ATTN_EXT: |
| 7482 | + if (!any_on_device) { |
| 7483 | + return false; |
| 7484 | + } |
| 7485 | + ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor); |
| 7486 | + return true; |
7242 | 7487 | default: |
7243 | 7488 | return false; |
7244 | 7489 | } |