From c9429b72d12f1d03e52cada0aa66303ffd37da08 Mon Sep 17 00:00:00 2001 From: yehudit-dev Date: Mon, 3 Nov 2025 14:50:15 +0200 Subject: [PATCH 1/2] sycl: initialize flash-attention implementation Co-authored-by: safranowith Co-authored-by: ye-NX --- ggml/src/ggml-sycl/CMakeLists.txt | 9 ++ .../ggml-sycl/flash-attn/flash-attn-sycl.cpp | 98 ++++++++++++++++ .../ggml-sycl/flash-attn/flash-attn-sycl.h | 10 ++ .../flash-attn/kernels/flash-attn-kernel.h | 108 ++++++++++++++++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 5 + 5 files changed, 230 insertions(+) create mode 100644 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp create mode 100644 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h create mode 100644 ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index efd78b912..19f96607b 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp") file(GLOB GGML_SOURCES_SYCL "*.cpp") target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) +# Include flash-attn sources (SYCL optimized flash attention implementation) +file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp") +file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c") +target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH}) + +# Also include kernel headers under flash-attn/kernels +file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp") +target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS}) + if (WIN32) # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C")) diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp new file mode 100644 index 000000000..1dbc4b952 --- /dev/null +++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp @@ -0,0 +1,98 @@ +#include "flash-attn-sycl.h" + +#include "kernels/flash-attn-kernel.h" + +#include +#include +#include +#include + +#define FLASH_ATTN_BR_MAX 32 +#define FLASH_ATTN_BC_MAX 32 + +// Flash Attention: https://arxiv.org/abs/2205.14135 +void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + GGML_ASSERT(Q != nullptr); + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + GGML_ASSERT(dst != nullptr); + + if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type); + return; + } + + const float * Q_d = (const float *) Q->data; + const float * K_d = (const float *) K->data; + const float * V_d = (const float *) V->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + const int64_t d = Q->ne[0]; + const int64_t N = Q->ne[1]; + + float scale; + float max_bias; + float logit_softcap; + + std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + std::memcpy(&logit_softcap, (const float *) dst->op_params + 
2, sizeof(float)); + + const bool masked = (mask != nullptr); + + const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N); + const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N); + + const int Tr = (N + Br - 1) / Br; + const int Tc = (N + Bc - 1) / Bc; + + float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream); + float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream); + + stream->fill(l_d, 0.0f, N); + stream->fill(m_d, -std::numeric_limits::infinity(), N); + stream->fill(dst_d, 0.0f, N * d); + stream->wait(); + + for (int j = 0; j < Tc; ++j) { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) { + const int i = idx[0]; + flash_attn_tiled_kernel(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br, + Bc, N, d, masked, scale); + }); + }); + } + + stream->wait(); + + sycl::free(l_d, *stream); + sycl::free(m_d, *stream); +} + +bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + if (Q == nullptr || K == nullptr || V == nullptr) { + return false; + } + + if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_F32) { + return false; + } + + return true; +} diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h new file mode 100644 index 000000000..c50d09aa0 --- /dev/null +++ b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h @@ -0,0 +1,10 @@ +#pragma once + +#include "../common.hpp" + +// Flash attention operation for SYCL backend +// This implements the Flash Attention algorithm optimized for SYCL devices +void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +// Check if flash attention is supported for given tensor +bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h new file mode 100644 index 000000000..721007eab --- /dev/null +++ b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h @@ -0,0 +1,108 @@ +#pragma once + +#include + +template +inline void flash_attn_tiled_kernel(const float * Q, + const float * K, + const float * V, + float * O, + float * l, + float * m, + const int i_block, + const int j_block, + const int Br, + const int Bc, + const int N, + const int d, + const bool masked, + const float scale) { + const int i_start = i_block * Br; + const int j_start = j_block * Bc; + + float S[Br_MAX][Bc_MAX]; + float P[Br_MAX][Bc_MAX]; + float m_local[Br_MAX]; + float l_local[Br_MAX]; + + for (int qi = 0; qi < Br; ++qi) { + const int q_row = i_start + qi; + if (q_row >= N) { + continue; + } + + for (int kj = 0; kj < Bc; ++kj) { + const int k_row = j_start + kj; + if (k_row >= N) { + S[qi][kj] = -INFINITY; + continue; + } + + if (masked && k_row > q_row) { + S[qi][kj] = -INFINITY; + continue; + } + + float score = 0.0f; + for (int k = 0; k < d; ++k) { + score += Q[q_row * d + k] * K[k_row * d + k]; + } + S[qi][kj] = score * scale; + } + } + + for (int qi = 0; qi < Br; ++qi) { + const int q_row = i_start + qi; + if (q_row >= N) { + continue; + } + + m_local[qi] = -INFINITY; + for (int kj = 0; kj < Bc; ++kj) { + if (j_start + kj < N) { + m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]); + } + } + + l_local[qi] = 0.0f; + for (int kj = 0; kj < Bc; ++kj) { + if (j_start 
+ kj < N && !sycl::isinf(S[qi][kj])) { + P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]); + l_local[qi] += P[qi][kj]; + } else { + P[qi][kj] = 0.0f; + } + } + } + + for (int qi = 0; qi < Br; ++qi) { + const int q_row = i_start + qi; + if (q_row >= N) { + continue; + } + + const float m_old = m[q_row]; + const float m_new = sycl::fmax(m_old, m_local[qi]); + const float l_old = l[q_row]; + const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi]; + + const float correction_old = sycl::exp(m_old - m_new); + const float correction_new = sycl::exp(m_local[qi] - m_new); + + for (int k = 0; k < d; ++k) { + float pv = 0.0f; + for (int kj = 0; kj < Bc; ++kj) { + const int v_row = j_start + kj; + if (v_row < N) { + pv += P[qi][kj] * V[v_row * d + k]; + } + } + + const int o_idx = q_row * d + k; + O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new; + } + + l[q_row] = l_new; + m[q_row] = m_new; + } +} diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 33f903507..a37e89d4e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3839,6 +3839,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_sycl_argsort(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_sycl_op_flash_attn(ctx, dst); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_sycl_op_timestep_embedding(ctx, dst); break; @@ -4501,6 +4504,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_MEAN: case GGML_OP_ARGSORT: return ggml_is_contiguous(op->src[0]); + case GGML_OP_FLASH_ATTN_EXT: + return ggml_sycl_flash_attn_ext_supported(op); case GGML_OP_POOL_2D: case GGML_OP_ACC: return true; From c62b98b08374cbbc22f8b8a04d105d50d6ec4ad9 Mon Sep 17 00:00:00 2001 From: yehudit-dev Date: Sun, 23 Nov 2025 16:59:25 +0200 Subject: [PATCH 2/2] flash-attn sycl: apply fixes and remove old implementation Co-authored-by: safranowith Co-authored-by: ye-NX --- ggml/src/ggml-sycl/CMakeLists.txt | 6 +- ggml/src/ggml-sycl/backend.hpp | 1 + ggml/src/ggml-sycl/fattn.cpp | 212 ++++++++++++++++++ .../flash-attn-sycl.h => fattn.hpp} | 7 +- ggml/src/ggml-sycl/fattn_common.hpp | 12 + ggml/src/ggml-sycl/fattn_kernel.hpp | 141 ++++++++++++ .../ggml-sycl/flash-attn/flash-attn-sycl.cpp | 98 -------- .../flash-attn/kernels/flash-attn-kernel.h | 108 --------- 8 files changed, 374 insertions(+), 211 deletions(-) create mode 100644 ggml/src/ggml-sycl/fattn.cpp rename ggml/src/ggml-sycl/{flash-attn/flash-attn-sycl.h => fattn.hpp} (75%) create mode 100644 ggml/src/ggml-sycl/fattn_common.hpp create mode 100644 ggml/src/ggml-sycl/fattn_kernel.hpp delete mode 100644 ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp delete mode 100644 ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 19f96607b..fa460273d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -28,12 +28,12 @@ file(GLOB GGML_SOURCES_SYCL "*.cpp") target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) # Include flash-attn sources (SYCL optimized flash attention implementation) -file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp") -file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c") +file(GLOB GGML_HEADERS_SYCL_FLASH "fattn*.h" "fattn*.hpp") +file(GLOB GGML_SOURCES_SYCL_FLASH "fattn*.cpp" "fattn*.c") 
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
 
 # Also include kernel headers under flash-attn/kernels
-file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
+file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "fattn_kernel*.h" "fattn_kernel*.hpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
 
 if (WIN32)
diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index b1575b814..c9c3d5061 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -38,6 +38,7 @@
 #include "tsembd.hpp"
 #include "wkv.hpp"
 #include "pad_reflect_1d.hpp"
+#include "fattn.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
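The file added below drives the tiled kernels with per-row running statistics l_d (exp-sum) and m_d (row maximum). For reference, the merge rule that flash_attn_softmax_kernel applies when a new key/value block with block statistics (m_blk, l_blk) arrives is the usual online-softmax update,

    m_new = max(m_old, m_blk)
    l_new = exp(m_old - m_new) * l_old + exp(m_blk - m_new) * l_blk

and in the full FlashAttention recurrence (https://arxiv.org/abs/2205.14135, cited by the patch 1 kernel) the output accumulated so far is rescaled by exp(m_old - m_new) in the same step; the second kernel submission in ggml_sycl_op_flash_attn_2 then divides each finished output row by its accumulated l.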
diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp
new file mode 100644
index 000000000..67cafa10c
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn.cpp
@@ -0,0 +1,212 @@
+#include "./fattn.hpp"
+#include "./fattn_kernel.hpp"
+#include "./fattn_common.hpp"
+
+#include <sycl/sycl.hpp>
+#include <cstddef>
+#include <cstdint>
+#include <cmath>
+
+#define Br 32
+#define Bc 32
+
+
+bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    if (Q == nullptr || K == nullptr || V == nullptr) {
+        return false;
+    }
+    if (Q->type == GGML_TYPE_F32 && K->type == GGML_TYPE_F32 && V->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        return true;
+    }
+    // if (Q->type == GGML_TYPE_F16 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+    //     return true;
+    // }
+
+    return false;
+}
+
+template <int DQK, int DV>
+void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    GGML_ASSERT(Q != nullptr);
+    GGML_ASSERT(K != nullptr);
+    GGML_ASSERT(V != nullptr);
+    GGML_ASSERT(dst != nullptr);
+
+    // KV cache is not supported yet
+    GGML_ASSERT(K->ne[1] == V->ne[1]);
+
+    // multi-head and GQA are not supported yet
+    GGML_ASSERT(Q->ne[2] == 1);
+    GGML_ASSERT(K->ne[2] == 1);
+    GGML_ASSERT(V->ne[2] == 1);
+
+    const float * Q_d   = (const float *) Q->data;
+    const float * K_d   = (const float *) K->data;
+    const float * V_d   = (const float *) V->data;
+    float *       dst_d = (float *) dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    const int64_t N = Q->ne[1];
+
+    const ptrdiff_t q_row_stride = Q->nb[1] / (ptrdiff_t) sizeof(float);
+    const ptrdiff_t k_row_stride = K->nb[1] / (ptrdiff_t) sizeof(float);
+    const ptrdiff_t v_row_stride = V->nb[1] / (ptrdiff_t) sizeof(float);
+    const ptrdiff_t o_row_stride = dst->nb[1] / (ptrdiff_t) sizeof(float);
+
+    // const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
+    // const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
+
+    const int Tr = (N + Br - 1) / Br;
+    const int Tc = (N + Bc - 1) / Bc;
+
+    float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
+    float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
+
+    // initialize the running softmax statistics: l = 0, m = -inf
+    stream->fill(l_d, 0.0f, N);
+    stream->fill(m_d, -INFINITY, N);
+
+    sycl::range<2> global(Br * Tr, Tc);
+    sycl::range<2> local(Br, 1);
+
+    stream->submit([&](sycl::handler & cgh) {
+        sycl::local_accessor<float, 2> Qtile({Br, DQK}, cgh);
+        sycl::local_accessor<float, 2> Ktile({Bc, DQK}, cgh);
+        sycl::local_accessor<float, 2> Vtile({Bc, DV}, cgh);
+        sycl::local_accessor<float, 2> Stile({Br, Bc}, cgh);
+        sycl::local_accessor<float, 1> Ptile({Br * Bc}, cgh);
+        sycl::local_accessor<float, 1> m_local({Br}, cgh);
+        sycl::local_accessor<float, 1> l_local({Br}, cgh);
+
+        float * q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
+        float * k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
+        float * v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
+        float * s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
+        float * p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
+        float * m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
+        float * l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
+
+        cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
+            auto group = it.get_group();
+            int group_id_i = group.get_group_id(0);
+            int group_id_j = group.get_group_id(1);
+
+            int row0 = group_id_i * Br;
+            int col0 = group_id_j * Bc;
+
+            if (row0 >= (int) N || col0 >= (int) N) {
+                return;
+            }
+
+            const float * Q_block = Q_d + (ptrdiff_t) row0 * q_row_stride;
+            const float * K_block = K_d + (ptrdiff_t) col0 * k_row_stride;
+            const float * V_block = V_d + (ptrdiff_t) col0 * v_row_stride;
+            float *       O_block = dst_d + (ptrdiff_t) row0 * o_row_stride;
+
+            // these copies do not support non-contiguous tensors
+            ggml_sycl_memcpy<Br * DQK>(q_loc, Q_block);
+            ggml_sycl_memcpy<Bc * DQK>(k_loc, K_block);
+            ggml_sycl_memcpy<Bc * DV>(v_loc, V_block);
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            flash_attn_mul_mat_QK_kernel<DQK>(
+                it,
+                Q_block, q_row_stride,
+                K_block, k_row_stride,
+                s_loc, (ptrdiff_t) Bc,
+                Br, Bc
+            );
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            flash_attn_softmax_kernel(
+                it,
+                s_loc, p_loc,
+                m_loc, l_loc,
+                Br, Bc,
+                l_d, m_d
+            );
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            flash_attn_mul_mat_PV_kernel<DV>(
+                it,
+                p_loc, (ptrdiff_t) Bc,
+                V_block, v_row_stride,
+                O_block, o_row_stride,
+                Br, Bc
+            );
+
+            it.barrier(sycl::access::fence_space::local_space);
+        });
+    });
+
+
+    stream->submit([&](sycl::handler & cgh) {
+        const ptrdiff_t o_stride = o_row_stride;
+
+        cgh.parallel_for(sycl::range<1>(N), [=](sycl::id<1> id_row) {
+            int   row   = id_row[0];
+            float l_val = l_d[row];
+
+            if (l_val <= 0.0f) {
+                return;
+            }
+
+            float   inv_l = 1.0f / l_val;
+            float * o_row = dst_d + (ptrdiff_t) row * o_stride;
+
+            for (int col = 0; col < DV; ++col) {
+                o_row[col] *= inv_l;
+            }
+        });
+    });
+
+    // make sure the kernels using l_d/m_d have finished before freeing them
+    stream->wait();
+
+    sycl::free(l_d, *stream);
+    sycl::free(m_d, *stream);
+}
+
+
+void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * V = dst->src[2];
+
+    switch (Q->ne[0]) {
+        case 64:
+            GGML_ASSERT(V->ne[0] == 64);
+            ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
+            break;
+        case 80:
+            GGML_ASSERT(V->ne[0] == 80);
+            ggml_sycl_op_flash_attn_2< 80, 80>(ctx, dst);
+            break;
+        case 96:
+            GGML_ASSERT(V->ne[0] == 96);
+            ggml_sycl_op_flash_attn_2< 96, 96>(ctx, dst);
+            break;
+        case 112:
+            GGML_ASSERT(V->ne[0] == 112);
+            ggml_sycl_op_flash_attn_2<112, 112>(ctx, dst);
+            break;
+        case 128:
+            GGML_ASSERT(V->ne[0] == 128);
+            ggml_sycl_op_flash_attn_2<128, 128>(ctx, dst);
+            break;
+        case 256:
+            GGML_ASSERT(V->ne[0] == 256);
+            ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
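For orientation, this is roughly how the op above is reached from a ggml graph. The sketch below is illustrative only (build_fattn and its parameters are invented here, not part of the patch); it respects the constraints the new path asserts: F32 tensors, a single head, equal K/V sequence lengths, and a head size handled by the switch in ggml_sycl_op_flash_attn.

    #include <cmath>
    #include "ggml.h"

    // Illustrative only: builds one flash-attention node that the SYCL backend
    // can route to ggml_sycl_op_flash_attn. Tensor layout is ggml's usual
    // [head_dim, n_tokens, n_head, n_batch]; n_head must be 1 for this first version.
    static ggml_tensor * build_fattn(ggml_context * ctx, int64_t head_dim, int64_t n_tokens) {
        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, 1, 1);
        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, 1, 1);
        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, 1, 1);

        const float scale = 1.0f / std::sqrt((float) head_dim);
        // mask, max_bias and logit_softcap are accepted by the op but not yet
        // consumed by these SYCL kernels.
        return ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/nullptr, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);
    }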
diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h b/ggml/src/ggml-sycl/fattn.hpp
similarity index 75%
rename from ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
rename to ggml/src/ggml-sycl/fattn.hpp
index c50d09aa0..94188e18e 100644
--- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
+++ b/ggml/src/ggml-sycl/fattn.hpp
@@ -1,6 +1,7 @@
-#pragma once
+#ifndef GGML_SYCL_FATTN_HPP
+#define GGML_SYCL_FATTN_HPP
 
-#include "../common.hpp"
+#include "common.hpp"
 
 // Flash attention operation for SYCL backend
 // This implements the Flash Attention algorithm optimized for SYCL devices
@@ -8,3 +9,5 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst
 
 // Check if flash attention is supported for given tensor
 bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
+
+#endif // GGML_SYCL_FATTN_HPP
diff --git a/ggml/src/ggml-sycl/fattn_common.hpp b/ggml/src/ggml-sycl/fattn_common.hpp
new file mode 100644
index 000000000..53a99f44f
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn_common.hpp
@@ -0,0 +1,12 @@
+#ifndef GGML_SYCL_FATTN_COMMON_HPP
+#define GGML_SYCL_FATTN_COMMON_HPP
+
+template <int N>
+inline void ggml_sycl_memcpy(float * dst, const float * src) {
+    #pragma unroll
+    for (int i = 0; i < N; ++i) {
+        dst[i] = src[i];
+    }
+}
+
+#endif // GGML_SYCL_FATTN_COMMON_HPP
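ggml_sycl_memcpy above assumes the source rows are densely packed, which is why the tile loads in fattn.cpp are flagged as contiguous-only. A possible follow-up, sketched here as an assumption rather than part of the patch, is a strided variant that walks the source with an explicit row stride:

    #include <cstddef>

    // Hypothetical helper (not in the patch): copy a ROWS x COLS tile from a source
    // whose rows are src_row_stride floats apart, so non-contiguous Q/K/V views
    // could be staged into local memory as well.
    template <int ROWS, int COLS>
    inline void ggml_sycl_memcpy_2d(float * dst, const float * src, std::ptrdiff_t src_row_stride) {
        for (int r = 0; r < ROWS; ++r) {
    #pragma unroll
            for (int c = 0; c < COLS; ++c) {
                dst[r * COLS + c] = src[r * src_row_stride + c];
            }
        }
    }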
diff --git a/ggml/src/ggml-sycl/fattn_kernel.hpp b/ggml/src/ggml-sycl/fattn_kernel.hpp
new file mode 100644
index 000000000..822947056
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn_kernel.hpp
@@ -0,0 +1,141 @@
+#ifndef GGML_SYCL_FATTN_KERNEL_HPP
+#define GGML_SYCL_FATTN_KERNEL_HPP
+
+#include <sycl/sycl.hpp>
+
+template <int QKD>
+inline void flash_attn_mul_mat_QK_kernel(
+    sycl::nd_item<2> it,
+    const float * Q, ptrdiff_t q_row_stride,
+    const float * K, ptrdiff_t k_row_stride,
+    float * S, ptrdiff_t s_row_stride,
+    const int Br, const int Bc) {
+
+    const int i = it.get_local_id(0);
+    if (i >= Br) {
+        return;
+    }
+
+    const float * q_vec = Q + i * q_row_stride;
+    float *       s_row = S + i * s_row_stride;
+
+    for (int j = 0; j < Bc; ++j) {
+        const float * k_vec = K + j * k_row_stride;
+        float score = 0.0f;
+
+#pragma unroll
+        for (int k = 0; k < QKD; ++k) {
+            score += q_vec[k] * k_vec[k];
+        }
+
+        s_row[j] = score;
+    }
+}
+
+
+inline void flash_attn_softmax_kernel(
+    sycl::nd_item<2> it,
+    float * S, float * P,
+    float * m_local, float * l_local,
+    const int Br, const int Bc,
+    float * l_d, float * m_d
+) {
+    const int li  = it.get_local_id(0);
+    const int gi  = it.get_group(0);
+    const int row = gi * Br + li;
+
+    if (li >= Br) {
+        return;
+    }
+
+    const int row_offset = li * Bc;
+
+    float m_old = m_d[row];
+    float l_old = l_d[row];
+
+    // 2. Block max
+    float m_block = -INFINITY;
+    for (int j = 0; j < Bc; ++j) {
+        const float s_ij = S[row_offset + j];
+        m_block = sycl::fmax(m_block, s_ij);
+    }
+
+    // 3. Block exp-sum
+    float l_block = 0.0f;
+    for (int j = 0; j < Bc; ++j) {
+        const float e = sycl::exp(S[row_offset + j] - m_block);
+        P[row_offset + j] = e;   // temporary store
+        l_block += e;
+    }
+
+    // 4. Merge block stats with global (streaming softmax)
+    float m_new;
+    float l_new;
+
+    if (l_old == 0.0f && m_old == -INFINITY) {
+        // first block for this row
+        m_new = m_block;
+        l_new = l_block;
+    } else {
+        m_new = sycl::fmax(m_old, m_block);
+
+        const float alpha = sycl::exp(m_old - m_new);
+        const float beta  = sycl::exp(m_block - m_new);
+
+        l_new = alpha * l_old + beta * l_block;
+    }
+
+    // 5. Store updated global stats
+    m_d[row] = m_new;
+    l_d[row] = l_new;
+
+    // 6. Convert local e_ij to global probabilities p_ij
+    float scale_block = 0.0f;
+    if (l_new > 0.0f) {
+        scale_block = sycl::exp(m_block - m_new) / l_new;
+    }
+
+    for (int j = 0; j < Bc; ++j) {
+        P[row_offset + j] *= scale_block;
+    }
+
+    // 7. 
Optional: keep local copies + m_local[li] = m_new; + l_local[li] = l_new; +} + + + + +template +inline void flash_attn_mul_mat_PV_kernel( + sycl::nd_item<2> it, + const float * P, ptrdiff_t p_row_stride, + const float * V, ptrdiff_t v_row_stride, + float * O, ptrdiff_t o_row_stride, + const int Br, const int Bc) { + + const int i = it.get_local_id(0); + if (i >= Br) { + return; + } + + const float * p_row = P + i * p_row_stride; + float * o_row = O + i * o_row_stride; + + for (int j = 0; j < VD; ++j) { + float acc = 0.0f; + +#pragma unroll + for (int k = 0; k < Bc; ++k) { + const float * v_row = V + k * v_row_stride; + acc += p_row[k] * v_row[j]; + } + + o_row[j] = acc; + } +} + +#endif // GGML_SYCL_FATTN_KERNEL_HPP + + diff --git a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp b/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp deleted file mode 100644 index 1dbc4b952..000000000 --- a/ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "flash-attn-sycl.h" - -#include "kernels/flash-attn-kernel.h" - -#include -#include -#include -#include - -#define FLASH_ATTN_BR_MAX 32 -#define FLASH_ATTN_BC_MAX 32 - -// Flash Attention: https://arxiv.org/abs/2205.14135 -void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; - - GGML_ASSERT(Q != nullptr); - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - GGML_ASSERT(dst != nullptr); - - if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { - fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type); - return; - } - - const float * Q_d = (const float *) Q->data; - const float * K_d = (const float *) K->data; - const float * V_d = (const float *) V->data; - float * dst_d = (float *) dst->data; - - dpct::queue_ptr stream = ctx.stream(); - - const int64_t d = Q->ne[0]; - const int64_t N = Q->ne[1]; - - float scale; - float max_bias; - float logit_softcap; - - std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); - std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); - - const bool masked = (mask != nullptr); - - const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N); - const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N); - - const int Tr = (N + Br - 1) / Br; - const int Tc = (N + Bc - 1) / Bc; - - float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream); - float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream); - - stream->fill(l_d, 0.0f, N); - stream->fill(m_d, -std::numeric_limits::infinity(), N); - stream->fill(dst_d, 0.0f, N * d); - stream->wait(); - - for (int j = 0; j < Tc; ++j) { - stream->submit([&](sycl::handler & cgh) { - cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) { - const int i = idx[0]; - flash_attn_tiled_kernel(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br, - Bc, N, d, masked, scale); - }); - }); - } - - stream->wait(); - - sycl::free(l_d, *stream); - sycl::free(m_d, *stream); -} - -bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - if (Q == nullptr || K == nullptr 
|| V == nullptr) { - return false; - } - - if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) { - return false; - } - - if (dst->type != GGML_TYPE_F32) { - return false; - } - - return true; -} diff --git a/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h b/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h deleted file mode 100644 index 721007eab..000000000 --- a/ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h +++ /dev/null @@ -1,108 +0,0 @@ -#pragma once - -#include - -template -inline void flash_attn_tiled_kernel(const float * Q, - const float * K, - const float * V, - float * O, - float * l, - float * m, - const int i_block, - const int j_block, - const int Br, - const int Bc, - const int N, - const int d, - const bool masked, - const float scale) { - const int i_start = i_block * Br; - const int j_start = j_block * Bc; - - float S[Br_MAX][Bc_MAX]; - float P[Br_MAX][Bc_MAX]; - float m_local[Br_MAX]; - float l_local[Br_MAX]; - - for (int qi = 0; qi < Br; ++qi) { - const int q_row = i_start + qi; - if (q_row >= N) { - continue; - } - - for (int kj = 0; kj < Bc; ++kj) { - const int k_row = j_start + kj; - if (k_row >= N) { - S[qi][kj] = -INFINITY; - continue; - } - - if (masked && k_row > q_row) { - S[qi][kj] = -INFINITY; - continue; - } - - float score = 0.0f; - for (int k = 0; k < d; ++k) { - score += Q[q_row * d + k] * K[k_row * d + k]; - } - S[qi][kj] = score * scale; - } - } - - for (int qi = 0; qi < Br; ++qi) { - const int q_row = i_start + qi; - if (q_row >= N) { - continue; - } - - m_local[qi] = -INFINITY; - for (int kj = 0; kj < Bc; ++kj) { - if (j_start + kj < N) { - m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]); - } - } - - l_local[qi] = 0.0f; - for (int kj = 0; kj < Bc; ++kj) { - if (j_start + kj < N && !sycl::isinf(S[qi][kj])) { - P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]); - l_local[qi] += P[qi][kj]; - } else { - P[qi][kj] = 0.0f; - } - } - } - - for (int qi = 0; qi < Br; ++qi) { - const int q_row = i_start + qi; - if (q_row >= N) { - continue; - } - - const float m_old = m[q_row]; - const float m_new = sycl::fmax(m_old, m_local[qi]); - const float l_old = l[q_row]; - const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi]; - - const float correction_old = sycl::exp(m_old - m_new); - const float correction_new = sycl::exp(m_local[qi] - m_new); - - for (int k = 0; k < d; ++k) { - float pv = 0.0f; - for (int kj = 0; kj < Bc; ++kj) { - const int v_row = j_start + kj; - if (v_row < N) { - pv += P[qi][kj] * V[v_row * d + k]; - } - } - - const int o_idx = q_row * d + k; - O[o_idx] = (correction_old * O[o_idx] + correction_new * pv) / l_new; - } - - l[q_row] = l_new; - m[q_row] = m_new; - } -}
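A quick way to sanity-check the merge rule used by flash_attn_softmax_kernel, independent of any SYCL device, is to compare the streamed (m, l) statistics against a single-pass computation on the host. The snippet below is a standalone sketch (the row length, block size reuse of Bc = 32, and score values are arbitrary, not part of the patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int N  = 97;   // row length (deliberately not a multiple of the block size)
        const int Bc = 32;   // block width, matching the Bc tile size in fattn.cpp

        std::vector<float> s(N);
        for (int i = 0; i < N; ++i) {
            s[i] = std::sin(0.37f * i) * 5.0f;   // arbitrary attention scores
        }

        // One-pass reference: m = max(s), l = sum(exp(s - m)).
        float m_ref = -INFINITY;
        for (float v : s) m_ref = std::max(m_ref, v);
        float l_ref = 0.0f;
        for (float v : s) l_ref += std::exp(v - m_ref);

        // Streaming version: process the row block by block and merge the statistics
        // with the same recurrence the SYCL kernel applies per key/value block.
        float m = -INFINITY, l = 0.0f;
        for (int j0 = 0; j0 < N; j0 += Bc) {
            const int j1 = std::min(j0 + Bc, N);

            float m_blk = -INFINITY, l_blk = 0.0f;
            for (int j = j0; j < j1; ++j) m_blk = std::max(m_blk, s[j]);
            for (int j = j0; j < j1; ++j) l_blk += std::exp(s[j] - m_blk);

            const float m_new = std::max(m, m_blk);
            l = std::exp(m - m_new) * l + std::exp(m_blk - m_new) * l_blk;
            m = m_new;
        }

        std::printf("m: %.6f vs %.6f, l: %.6f vs %.6f\n", m, m_ref, l, l_ref);
        return std::fabs(l - l_ref) < 1e-3f * l_ref ? 0 : 1;
    }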