Skip to content

Commit e585cba

Browse files
committed
add F16/F16 fa support
1 parent 8a4a856 commit e585cba

File tree

5 files changed

+1250
-0
lines changed

5 files changed

+1250
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ set(GGML_OPENCL_KERNELS
109109
mul_mat_f16_f32
110110
conv2d
111111
conv2d_f16_f32
112+
flash_attn_f32_f16
113+
flash_attn_f16
114+
flash_attn_f32
112115
)
113116

114117
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <vector>
2626
#include <string>
2727
#include <cmath>
28+
#include <map>
2829
#include <memory>
2930
#include <charconv>
3031
#include <mutex>
@@ -420,6 +421,13 @@ struct ggml_backend_opencl_context {
420421
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
421422
cl_kernel kernel_soft_max, kernel_soft_max_4;
422423
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
424+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
425+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
426+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
427+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
428+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
429+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
430+
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
423431
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
424432
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
425433
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -1263,6 +1271,75 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
12631271
GGML_LOG_CONT(".");
12641272
}
12651273

1274+
// flash_attn
1275+
{
1276+
#ifdef GGML_OPENCL_EMBED_KERNELS
1277+
const std::string kernel_src_f16 {
1278+
#include "flash_attn_f16.cl.h"
1279+
};
1280+
const std::string kernel_src_f32 {
1281+
#include "flash_attn_f32.cl.h"
1282+
};
1283+
const std::string kernel_src_f32_f16 {
1284+
#include "flash_attn_f32_f16.cl.h"
1285+
};
1286+
#else
1287+
const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
1288+
const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
1289+
const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
1290+
#endif
1291+
1292+
if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
1293+
const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
1294+
{ 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
1295+
{112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
1296+
{192, 192, 16, 16}, {256, 256, 16, 16},
1297+
};
1298+
1299+
for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
1300+
const int dk = fa_dims[i].dk;
1301+
const int dv = fa_dims[i].dv;
1302+
const int bm = fa_dims[i].bm;
1303+
const int bn = fa_dims[i].bn;
1304+
std::string OPTS = compile_opts +
1305+
" -D DK=" + std::to_string(dk) +
1306+
" -D DV=" + std::to_string(dv) +
1307+
" -D BLOCK_M=" + std::to_string(bm) +
1308+
" -D BLOCK_N=" + std::to_string(bn);
1309+
1310+
cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
1311+
cl_kernel k_f16, k_f16_q1;
1312+
CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
1313+
CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
1314+
GGML_ASSERT(k_f16 != NULL && k_f16_q1 != NULL);
1315+
backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
1316+
backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
1317+
CL_CHECK(clReleaseProgram(prog_f16));
1318+
1319+
cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
1320+
cl_kernel k_f32, k_f32_q1;
1321+
CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
1322+
CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
1323+
GGML_ASSERT(k_f32 != NULL && k_f32_q1 != NULL);
1324+
backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
1325+
backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
1326+
CL_CHECK(clReleaseProgram(prog_f32));
1327+
1328+
cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
1329+
cl_kernel k_f32_f16, k_f32_f16_q1;
1330+
CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
1331+
CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
1332+
GGML_ASSERT(k_f32_f16 != NULL && k_f32_f16_q1 != NULL);
1333+
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
1334+
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
1335+
CL_CHECK(clReleaseProgram(prog_f32_f16));
1336+
1337+
backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
1338+
}
1339+
GGML_LOG_CONT(".");
1340+
}
1341+
}
1342+
12661343
// argsort
12671344
{
12681345
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2553,6 +2630,41 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
25532630
return op->src[0]->type == GGML_TYPE_F32;
25542631
case GGML_OP_SUM_ROWS:
25552632
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
2633+
case GGML_OP_FLASH_ATTN_EXT:
2634+
{
2635+
const ggml_tensor * q = op->src[0];
2636+
const ggml_tensor * k = op->src[1];
2637+
const ggml_tensor * v = op->src[2];
2638+
2639+
const int dk = q->ne[0];
2640+
const int dv = v->ne[0];
2641+
2642+
const struct { int dk; int dv; } supported_dims[] = {
2643+
{ 64, 64}, { 80, 80}, { 96, 96},
2644+
{112, 112}, {128, 128}, {192, 128},
2645+
{192, 192}, {256, 256},
2646+
};
2647+
2648+
bool dims_supported = false;
2649+
for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
2650+
if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
2651+
dims_supported = true;
2652+
break;
2653+
}
2654+
}
2655+
if (!dims_supported) {
2656+
return false;
2657+
}
2658+
2659+
const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
2660+
v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2661+
const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
2662+
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
2663+
const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
2664+
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
2665+
2666+
return is_f32_f32 || is_f16_f16 || is_f32_f16;
2667+
}
25562668
default:
25572669
return false;
25582670
}
@@ -5193,6 +5305,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
51935305
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
51945306
}
51955307

5308+
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
5309+
const ggml_tensor * v = dst->src[2];
5310+
const ggml_tensor * mask = dst->src[3];
5311+
GGML_ASSERT(q->extra);
5312+
GGML_ASSERT(k->extra);
5313+
GGML_ASSERT(v->extra);
5314+
GGML_ASSERT(dst->extra);
5315+
if (mask) {
5316+
GGML_ASSERT(mask->extra);
5317+
}
5318+
5319+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5320+
5321+
const int n_q = q->ne[1];
5322+
const int n_kv = k->ne[1];
5323+
const int d_head_q = q->ne[0];
5324+
const int d_head_v = v->ne[0];
5325+
const int n_head = q->ne[2];
5326+
const int n_head_kv = k->ne[2];
5327+
const int n_batch = q->ne[3];
5328+
5329+
cl_kernel kernel = NULL;
5330+
5331+
const bool is_f16 = q->type == GGML_TYPE_F16;
5332+
const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
5333+
const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
5334+
5335+
if (n_q == 1) {
5336+
if (is_mixed) {
5337+
kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
5338+
} else if (is_f16) {
5339+
kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
5340+
} else {
5341+
kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
5342+
}
5343+
} else {
5344+
if (is_mixed) {
5345+
kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
5346+
} else if (is_f16) {
5347+
kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
5348+
} else {
5349+
kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
5350+
}
5351+
}
5352+
GGML_ASSERT(kernel != NULL);
5353+
5354+
ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
5355+
ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
5356+
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
5357+
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
5358+
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
5359+
5360+
cl_ulong offset_q = extra_q->offset + q->view_offs;
5361+
cl_ulong offset_k = extra_k->offset + k->view_offs;
5362+
cl_ulong offset_v = extra_v->offset + v->view_offs;
5363+
cl_ulong offset_o = extra_o->offset + dst->view_offs;
5364+
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
5365+
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
5366+
5367+
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
5368+
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
5369+
const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
5370+
const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
5371+
const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
5372+
const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
5373+
const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
5374+
const int mask_ne2 = mask ? mask->ne[2] : 0;
5375+
const int mask_ne3 = mask ? mask->ne[3] : 0;
5376+
5377+
float scale, max_bias, logit_softcap;
5378+
const float * params = (const float *)dst->op_params;
5379+
scale = params[0];
5380+
max_bias = params[1];
5381+
logit_softcap = params[2];
5382+
5383+
const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
5384+
5385+
const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
5386+
const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
5387+
const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
5388+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
5389+
5390+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
5391+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
5392+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
5393+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
5394+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
5395+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
5396+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
5397+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
5398+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
5399+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
5400+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
5401+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
5402+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
5403+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
5404+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
5405+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
5406+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
5407+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
5408+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
5409+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
5410+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
5411+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
5412+
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
5413+
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
5414+
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
5415+
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
5416+
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
5417+
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
5418+
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
5419+
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
5420+
5421+
if (n_q == 1) {
5422+
const size_t wg_size = 64;
5423+
size_t local_work_size[] = { wg_size, 1 };
5424+
size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
5425+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
5426+
} else {
5427+
const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
5428+
const size_t wg_size = block_m;
5429+
size_t local_work_size[] = { wg_size, 1 };
5430+
size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
5431+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
5432+
}
5433+
}
5434+
51965435
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
51975436
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
51985437

@@ -7239,6 +7478,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
72397478
}
72407479
func = ggml_cl_sum_rows;
72417480
break;
7481+
case GGML_OP_FLASH_ATTN_EXT:
7482+
if (!any_on_device) {
7483+
return false;
7484+
}
7485+
ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
7486+
return true;
72427487
default:
72437488
return false;
72447489
}

0 commit comments

Comments
 (0)