 #include <vector>
 #include <string>
 #include <cmath>
+#include <map>
 #include <memory>
 #include <charconv>
 #include <mutex>
@@ -424,6 +425,14 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
+    std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
+    std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
     cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
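
The flash-attention kernels are specialized per head-size pair, so instead of one cl_kernel member per variant the context keeps std::map containers keyed by (DK, DV); kernels_flash_attn_bm / kernels_flash_attn_bn record the tile shape each pair was compiled with. A minimal sketch of the lookup this enables (the 128/128 pair is an illustrative choice, not part of the patch):

    // Fetch the f16 kernel compiled for 128/128 heads; std::map::at throws
    // for a pair that was never built, which surfaces unsupported head sizes
    // instead of silently launching a null kernel.
    std::pair<int, int> dims{128, 128};
    cl_kernel k = backend_ctx->kernels_flash_attn_f16.at(dims);
    int block_m = backend_ctx->kernels_flash_attn_bm.at(dims);
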
@@ -1308,6 +1317,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // flash_attn
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src_f16 {
+#include "flash_attn_f16.cl.h"
+        };
+        const std::string kernel_src_f32 {
+#include "flash_attn_f32.cl.h"
+        };
+        const std::string kernel_src_f32_f16 {
+#include "flash_attn_f32_f16.cl.h"
+        };
+#else
+        const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
+        const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
+        const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
+#endif
+
+        if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
+            const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
+                { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
+                {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
+                {192, 192, 16, 16}, {256, 256, 16, 16},
+            };
+
+            for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
+                const int dk = fa_dims[i].dk;
+                const int dv = fa_dims[i].dv;
+                const int bm = fa_dims[i].bm;
+                const int bn = fa_dims[i].bn;
+                std::string OPTS = compile_opts +
+                    " -D DK=" + std::to_string(dk) +
+                    " -D DV=" + std::to_string(dv) +
+                    " -D BLOCK_M=" + std::to_string(bm) +
+                    " -D BLOCK_N=" + std::to_string(bn);
+
+                cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
+                cl_kernel k_f16, k_f16_q1;
+                CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
+                CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
+                backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
+                backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
+                CL_CHECK(clReleaseProgram(prog_f16));
+
+                cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
+                cl_kernel k_f32, k_f32_q1;
+                CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
+                CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
+                backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
+                backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
+                CL_CHECK(clReleaseProgram(prog_f32));
+
+                cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
+                cl_kernel k_f32_f16, k_f32_f16_q1;
+                CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
+                CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
+                backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
+                backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
+                CL_CHECK(clReleaseProgram(prog_f32_f16));
+
+                backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
+                backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
+            }
+            GGML_LOG_CONT(".");
+        }
+    }
+
     // argsort
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
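
Each fa_dims entry builds three programs (f16, f32, and f32 query with f16 K/V) with the head and tile sizes baked in through -D macros, so DK, DV, BLOCK_M and BLOCK_N become compile-time constants inside the .cl sources. A worked expansion for one entry (assuming compile_opts is the backend's base option string; the stand-in value below is illustrative only):

    // For the {128, 128, 32, 32} entry the loop produces:
    std::string opts = std::string("-cl-std=CL3.0")  // stand-in for compile_opts
                     + " -D DK=" + std::to_string(128)
                     + " -D DV=" + std::to_string(128)
                     + " -D BLOCK_M=" + std::to_string(32)
                     + " -D BLOCK_N=" + std::to_string(32);
    // opts == "-cl-std=CL3.0 -D DK=128 -D DV=128 -D BLOCK_M=32 -D BLOCK_N=32"
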
@@ -2636,6 +2712,45 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM_ROWS:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                if (op->src[4]) {
+                    return false;
+                }
+
+                const ggml_tensor * q = op->src[0];
+                const ggml_tensor * k = op->src[1];
+                const ggml_tensor * v = op->src[2];
+
+                const int dk = q->ne[0];
+                const int dv = v->ne[0];
+
+                const struct { int dk; int dv; } supported_dims[] = {
+                    { 64, 64}, { 80, 80}, { 96, 96},
+                    {112, 112}, {128, 128}, {192, 128},
+                    {192, 192}, {256, 256},
+                };
+
+                bool dims_supported = false;
+                for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
+                    if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
+                        dims_supported = true;
+                        break;
+                    }
+                }
+                if (!dims_supported) {
+                    return false;
+                }
+
+                const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
+                                        v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+                const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
+                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
+                const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
+                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
+
+                return is_f32_f32 || is_f16_f16 || is_f32_f16;
+            }
         default:
             return false;
     }
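
The (dk, dv) whitelist mirrors the fa_dims table used at kernel-load time, so the backend only claims GGML_OP_FLASH_ATTN_EXT nodes whose head sizes have a compiled kernel; anything else, and any node with a non-null src[4], is left for another backend. The same check as a standalone sketch, under the assumption that dk = q->ne[0] and dv = v->ne[0] are the K and V head dimensions:

    // Hypothetical helper, not in the patch: true when a (dk, dv) pair
    // has a specialized flash-attention kernel variant.
    static bool fa_dims_supported(int dk, int dv) {
        const struct { int dk; int dv; } dims[] = {
            { 64,  64}, { 80,  80}, { 96,  96}, {112, 112},
            {128, 128}, {192, 128}, {192, 192}, {256, 256},
        };
        for (size_t i = 0; i < sizeof(dims)/sizeof(dims[0]); ++i) {
            if (dims[i].dk == dk && dims[i].dv == dv) {
                return true;
            }
        }
        return false; // e.g. dk = 72, or dk = 192 with dv = 64
    }
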
@@ -5451,6 +5566,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
 }
 
+static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
+    const ggml_tensor * v = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+    GGML_ASSERT(q->extra);
+    GGML_ASSERT(k->extra);
+    GGML_ASSERT(v->extra);
+    GGML_ASSERT(dst->extra);
+    if (mask) {
+        GGML_ASSERT(mask->extra);
+    }
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    const int n_q = q->ne[1];
+    const int n_kv = k->ne[1];
+    const int d_head_q = q->ne[0];
+    const int d_head_v = v->ne[0];
+    const int n_head = q->ne[2];
+    const int n_head_kv = k->ne[2];
+    const int n_batch = q->ne[3];
+
+    cl_kernel kernel = NULL;
+
+    const bool is_f16 = q->type == GGML_TYPE_F16;
+    const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
+    const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
+
+    if (n_q == 1) {
+        if (is_mixed) {
+            kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
+        } else if (is_f16) {
+            kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
+        } else {
+            kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
+        }
+    } else {
+        if (is_mixed) {
+            kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
+        } else if (is_f16) {
+            kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
+        } else {
+            kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
+        }
+    }
+    GGML_ASSERT(kernel != NULL);
+
+    ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
+    ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
+    ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
+    ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
+    ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
+
+    cl_ulong offset_q = extra_q->offset + q->view_offs;
+    cl_ulong offset_k = extra_k->offset + k->view_offs;
+    cl_ulong offset_v = extra_v->offset + v->view_offs;
+    cl_ulong offset_o = extra_o->offset + dst->view_offs;
+    cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
+    cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
+
+    const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
+    const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
+    const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
+    const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
+    const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
+    const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
+    const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
+    const int mask_ne2 = mask ? mask->ne[2] : 0;
+    const int mask_ne3 = mask ? mask->ne[3] : 0;
+
+    float scale, max_bias, logit_softcap;
+    const float * params = (const float *)dst->op_params;
+    scale = params[0];
+    max_bias = params[1];
+    logit_softcap = params[2];
+
+    const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
+
+    const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
+    const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
+    const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
+    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
+
+    if (n_q == 1) {
+        const size_t wg_size = 64;
+        size_t local_work_size[] = { wg_size, 1 };
+        size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
+        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+    } else {
+        const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
+        const size_t wg_size = block_m;
+        size_t local_work_size[] = { wg_size, 1 };
+        size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
+        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+    }
+}
+
 static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
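
Two launch shapes are used: the q1 path (a single query row, i.e. token-by-token decode) runs one 64-lane workgroup per (head, batch) pair, while the tiled path gives each workgroup BLOCK_M query rows, so the first global dimension is rounded up to a multiple of BLOCK_M work-items. The m0/m1 values are the usual ALiBi slope bases; with max_bias = 0 both evaluate to 1.0 and the bias is a no-op. A worked example of the tiled launch under assumed shapes (numbers chosen for illustration, not taken from the patch):

    // n_q = 512 query rows, n_head = 32, n_batch = 1, head sizes 128/128
    // -> block_m = kernels_flash_attn_bm.at({128, 128}) = 32
    // dim 0: ceil(512 / 32) = 16 workgroups * 32 work-items = 512
    // dim 1: one workgroup row per head * batch             = 32
    //   global_work_size = { 512, 32 }
    //   local_work_size  = { 32, 1 }
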
@@ -7607,6 +7849,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_sum_rows;
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            if (!any_on_device) {
+                return false;
+            }
+            ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
+            return true;
         default:
             return false;
     }