Skip to content

Commit 912ff8c

Browse files
authored
OpenCL: add initial FA support (#14987)
* add F16/F16 fa support * fix kernel init * use mad instead of fma * use inline function * mark FA with sinks as unsupported for now * add pragma unroll to loops
1 parent 5e6229a commit 912ff8c

File tree

5 files changed

+1283
-0
lines changed

5 files changed

+1283
-0
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ set(GGML_OPENCL_KERNELS
112112
mul_mat_f16_f32
113113
conv2d
114114
conv2d_f16_f32
115+
flash_attn_f32_f16
116+
flash_attn_f16
117+
flash_attn_f32
115118
)
116119

117120
foreach (K ${GGML_OPENCL_KERNELS})

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <vector>
2626
#include <string>
2727
#include <cmath>
28+
#include <map>
2829
#include <memory>
2930
#include <charconv>
3031
#include <mutex>
@@ -424,6 +425,14 @@ struct ggml_backend_opencl_context {
424425
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
425426
cl_kernel kernel_soft_max, kernel_soft_max_4;
426427
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
428+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
429+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
430+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
431+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
432+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
433+
std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
434+
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
435+
std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
427436
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
428437
cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
429438
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -1308,6 +1317,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
13081317
GGML_LOG_CONT(".");
13091318
}
13101319

1320+
// flash_attn
1321+
{
1322+
#ifdef GGML_OPENCL_EMBED_KERNELS
1323+
const std::string kernel_src_f16 {
1324+
#include "flash_attn_f16.cl.h"
1325+
};
1326+
const std::string kernel_src_f32 {
1327+
#include "flash_attn_f32.cl.h"
1328+
};
1329+
const std::string kernel_src_f32_f16 {
1330+
#include "flash_attn_f32_f16.cl.h"
1331+
};
1332+
#else
1333+
const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
1334+
const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
1335+
const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
1336+
#endif
1337+
1338+
if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
1339+
const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
1340+
{ 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
1341+
{112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
1342+
{192, 192, 16, 16}, {256, 256, 16, 16},
1343+
};
1344+
1345+
for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
1346+
const int dk = fa_dims[i].dk;
1347+
const int dv = fa_dims[i].dv;
1348+
const int bm = fa_dims[i].bm;
1349+
const int bn = fa_dims[i].bn;
1350+
std::string OPTS = compile_opts +
1351+
" -D DK=" + std::to_string(dk) +
1352+
" -D DV=" + std::to_string(dv) +
1353+
" -D BLOCK_M=" + std::to_string(bm) +
1354+
" -D BLOCK_N=" + std::to_string(bn);
1355+
1356+
cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
1357+
cl_kernel k_f16, k_f16_q1;
1358+
CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
1359+
CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
1360+
backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
1361+
backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
1362+
CL_CHECK(clReleaseProgram(prog_f16));
1363+
1364+
cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
1365+
cl_kernel k_f32, k_f32_q1;
1366+
CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
1367+
CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
1368+
backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
1369+
backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
1370+
CL_CHECK(clReleaseProgram(prog_f32));
1371+
1372+
cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
1373+
cl_kernel k_f32_f16, k_f32_f16_q1;
1374+
CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
1375+
CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
1376+
backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
1377+
backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
1378+
CL_CHECK(clReleaseProgram(prog_f32_f16));
1379+
1380+
backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
1381+
backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
1382+
}
1383+
GGML_LOG_CONT(".");
1384+
}
1385+
}
1386+
13111387
// argsort
13121388
{
13131389
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2636,6 +2712,45 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
26362712
return op->src[0]->type == GGML_TYPE_F32;
26372713
case GGML_OP_SUM_ROWS:
26382714
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
2715+
case GGML_OP_FLASH_ATTN_EXT:
2716+
{
2717+
if (op->src[4]) {
2718+
return false;
2719+
}
2720+
2721+
const ggml_tensor * q = op->src[0];
2722+
const ggml_tensor * k = op->src[1];
2723+
const ggml_tensor * v = op->src[2];
2724+
2725+
const int dk = q->ne[0];
2726+
const int dv = v->ne[0];
2727+
2728+
const struct { int dk; int dv; } supported_dims[] = {
2729+
{ 64, 64}, { 80, 80}, { 96, 96},
2730+
{112, 112}, {128, 128}, {192, 128},
2731+
{192, 192}, {256, 256},
2732+
};
2733+
2734+
bool dims_supported = false;
2735+
for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
2736+
if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
2737+
dims_supported = true;
2738+
break;
2739+
}
2740+
}
2741+
if (!dims_supported) {
2742+
return false;
2743+
}
2744+
2745+
const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
2746+
v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2747+
const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
2748+
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
2749+
const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
2750+
v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
2751+
2752+
return is_f32_f32 || is_f16_f16 || is_f32_f16;
2753+
}
26392754
default:
26402755
return false;
26412756
}
@@ -5451,6 +5566,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
54515566
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
54525567
}
54535568

5569+
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
5570+
const ggml_tensor * v = dst->src[2];
5571+
const ggml_tensor * mask = dst->src[3];
5572+
GGML_ASSERT(q->extra);
5573+
GGML_ASSERT(k->extra);
5574+
GGML_ASSERT(v->extra);
5575+
GGML_ASSERT(dst->extra);
5576+
if (mask) {
5577+
GGML_ASSERT(mask->extra);
5578+
}
5579+
5580+
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5581+
5582+
const int n_q = q->ne[1];
5583+
const int n_kv = k->ne[1];
5584+
const int d_head_q = q->ne[0];
5585+
const int d_head_v = v->ne[0];
5586+
const int n_head = q->ne[2];
5587+
const int n_head_kv = k->ne[2];
5588+
const int n_batch = q->ne[3];
5589+
5590+
cl_kernel kernel = NULL;
5591+
5592+
const bool is_f16 = q->type == GGML_TYPE_F16;
5593+
const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
5594+
const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
5595+
5596+
if (n_q == 1) {
5597+
if (is_mixed) {
5598+
kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
5599+
} else if (is_f16) {
5600+
kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
5601+
} else {
5602+
kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
5603+
}
5604+
} else {
5605+
if (is_mixed) {
5606+
kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
5607+
} else if (is_f16) {
5608+
kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
5609+
} else {
5610+
kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
5611+
}
5612+
}
5613+
GGML_ASSERT(kernel != NULL);
5614+
5615+
ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
5616+
ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
5617+
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
5618+
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
5619+
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
5620+
5621+
cl_ulong offset_q = extra_q->offset + q->view_offs;
5622+
cl_ulong offset_k = extra_k->offset + k->view_offs;
5623+
cl_ulong offset_v = extra_v->offset + v->view_offs;
5624+
cl_ulong offset_o = extra_o->offset + dst->view_offs;
5625+
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
5626+
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
5627+
5628+
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
5629+
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
5630+
const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
5631+
const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
5632+
const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
5633+
const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
5634+
const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
5635+
const int mask_ne2 = mask ? mask->ne[2] : 0;
5636+
const int mask_ne3 = mask ? mask->ne[3] : 0;
5637+
5638+
float scale, max_bias, logit_softcap;
5639+
const float * params = (const float *)dst->op_params;
5640+
scale = params[0];
5641+
max_bias = params[1];
5642+
logit_softcap = params[2];
5643+
5644+
const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
5645+
5646+
const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
5647+
const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
5648+
const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
5649+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
5650+
5651+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
5652+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
5653+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
5654+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
5655+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
5656+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
5657+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
5658+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
5659+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
5660+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
5661+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
5662+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
5663+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
5664+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
5665+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
5666+
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
5667+
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
5668+
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
5669+
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
5670+
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
5671+
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
5672+
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
5673+
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
5674+
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
5675+
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
5676+
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
5677+
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
5678+
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
5679+
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
5680+
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
5681+
5682+
if (n_q == 1) {
5683+
const size_t wg_size = 64;
5684+
size_t local_work_size[] = { wg_size, 1 };
5685+
size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
5686+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
5687+
} else {
5688+
const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
5689+
const size_t wg_size = block_m;
5690+
size_t local_work_size[] = { wg_size, 1 };
5691+
size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
5692+
backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
5693+
}
5694+
}
5695+
54545696
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
54555697
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
54565698

@@ -7607,6 +7849,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
76077849
}
76087850
func = ggml_cl_sum_rows;
76097851
break;
7852+
case GGML_OP_FLASH_ATTN_EXT:
7853+
if (!any_on_device) {
7854+
return false;
7855+
}
7856+
ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
7857+
return true;
76107858
default:
76117859
return false;
76127860
}

0 commit comments

Comments
 (0)