diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt
index efd78b912..fa460273d 100644
--- a/ggml/src/ggml-sycl/CMakeLists.txt
+++ b/ggml/src/ggml-sycl/CMakeLists.txt
@@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
 file(GLOB GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
+# Include flash-attn sources (SYCL-optimized flash attention implementation)
+file(GLOB GGML_HEADERS_SYCL_FLASH "fattn*.h" "fattn*.hpp")
+file(GLOB GGML_SOURCES_SYCL_FLASH "fattn*.cpp" "fattn*.c")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
+
+# Also include the flash-attn kernel headers (fattn_kernel*)
+file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "fattn_kernel*.h" "fattn_kernel*.hpp")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
+
 if (WIN32)
     # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
     if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index b1575b814..c9c3d5061 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -38,6 +38,7 @@
 #include "tsembd.hpp"
 #include "wkv.hpp"
 #include "pad_reflect_1d.hpp"
+#include "fattn.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
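The new fattn.cpp below implements FlashAttention-style tiling with a streaming (online) softmax: each work-group handles one Br x Bc tile of the N x N score matrix and folds that tile's block max `m_block` and block exp-sum `l_block` into per-row running statistics `(m, l)`. The standalone host-side sketch below is illustrative only (written for this note, not part of the patch) and shows why the merge reproduces the one-pass softmax denominator:

```cpp
// Streaming-softmax recurrence used by the kernels in this patch:
//   m_new = max(m_old, m_block)
//   l_new = exp(m_old - m_new) * l_old + exp(m_block - m_new) * l_block
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> scores = {1.0f, 3.0f, 2.0f, 0.5f, 4.0f, -1.0f};
    const int Bc = 2; // block width, mirroring the kernel's Bc tiles

    float m = -INFINITY, l = 0.0f;
    for (size_t b = 0; b < scores.size(); b += Bc) {
        // per-block max and exp-sum
        float m_block = -INFINITY, l_block = 0.0f;
        for (int j = 0; j < Bc; ++j) {
            m_block = std::fmax(m_block, scores[b + j]);
        }
        for (int j = 0; j < Bc; ++j) {
            l_block += std::exp(scores[b + j] - m_block);
        }
        // merge block stats into the running stats
        const float m_new = std::fmax(m, m_block);
        l = std::exp(m - m_new) * l + std::exp(m_block - m_new) * l_block;
        m = m_new;
    }
    // l now equals sum_j exp(s_j - m), the one-pass softmax denominator
    std::printf("m = %f, l = %f\n", m, l);
    return 0;
}
```

Because `exp(-inf - m_new)` evaluates to 0, the first block needs no special case in this scalar form; the device kernel below short-circuits it explicitly instead.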
diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp
new file mode 100644
index 000000000..67cafa10c
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn.cpp
@@ -0,0 +1,212 @@
+#include "./fattn.hpp"
+#include "./fattn_kernel.hpp"
+#include "./fattn_common.hpp"
+
+#include <sycl/sycl.hpp>
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+
+#define Br 32
+#define Bc 32
+
+bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    if (Q == nullptr || K == nullptr || V == nullptr) {
+        return false;
+    }
+    if (Q->type == GGML_TYPE_F32 && K->type == GGML_TYPE_F32 && V->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        return true;
+    }
+    // if (Q->type == GGML_TYPE_F16 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+    //     return true;
+    // }
+
+    return false;
+}
+
+template <int DQK, int DV>
+void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    GGML_ASSERT(Q != nullptr);
+    GGML_ASSERT(K != nullptr);
+    GGML_ASSERT(V != nullptr);
+    GGML_ASSERT(dst != nullptr);
+
+    // KV cache is not supported yet
+    GGML_ASSERT(K->ne[1] == V->ne[1]);
+
+    // multi-head attention and GQA are not supported yet
+    GGML_ASSERT(Q->ne[2] == 1);
+    GGML_ASSERT(K->ne[2] == 1);
+    GGML_ASSERT(V->ne[2] == 1);
+
+    const float * Q_d = (const float *) Q->data;
+    const float * K_d = (const float *) K->data;
+    const float * V_d = (const float *) V->data;
+    float * dst_d = (float *) dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    const int64_t N = Q->ne[1];
+
+    const ptrdiff_t q_row_stride = Q->nb[1] / (ptrdiff_t) sizeof(float);
+    const ptrdiff_t k_row_stride = K->nb[1] / (ptrdiff_t) sizeof(float);
+    const ptrdiff_t v_row_stride = V->nb[1] / (ptrdiff_t) sizeof(float);
+    const ptrdiff_t o_row_stride = dst->nb[1] / (ptrdiff_t) sizeof(float);
+
+    // const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
+    // const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
+
+    const int Tr = (N + Br - 1) / Br;
+    const int Tc = (N + Bc - 1) / Bc;
+
+    float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
+    float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
+
+    // the softmax kernel expects l = 0 and m = -inf for rows that have not
+    // seen a block yet, so initialize the stats before the launch
+    stream->fill(l_d, 0.0f, (size_t) N);
+    stream->fill(m_d, -INFINITY, (size_t) N);
+
+    sycl::range<2> global(Br * Tr, Tc);
+    sycl::range<2> local(Br, 1);
+
+    stream->submit([&](sycl::handler & cgh) {
+        sycl::local_accessor<float, 2> Qtile({Br, DQK}, cgh);
+        sycl::local_accessor<float, 2> Ktile({Bc, DQK}, cgh);
+        sycl::local_accessor<float, 2> Vtile({Bc, DV}, cgh);
+        sycl::local_accessor<float, 2> Stile({Br, Bc}, cgh);
+        sycl::local_accessor<float, 1> Ptile({Br * Bc}, cgh);
+        sycl::local_accessor<float, 1> m_local({Br}, cgh);
+        sycl::local_accessor<float, 1> l_local({Br}, cgh);
+
+        cgh.parallel_for(sycl::nd_range<2>(global, local), [=](sycl::nd_item<2> it) {
+            // local-memory pointers must be materialized inside the kernel
+            float * q_loc = Qtile.template get_multi_ptr<sycl::access::decorated::no>().get();
+            float * k_loc = Ktile.template get_multi_ptr<sycl::access::decorated::no>().get();
+            float * v_loc = Vtile.template get_multi_ptr<sycl::access::decorated::no>().get();
+            float * s_loc = Stile.template get_multi_ptr<sycl::access::decorated::no>().get();
+            float * p_loc = Ptile.template get_multi_ptr<sycl::access::decorated::no>().get();
+            float * m_loc = m_local.template get_multi_ptr<sycl::access::decorated::no>().get();
+            float * l_loc = l_local.template get_multi_ptr<sycl::access::decorated::no>().get();
+
+            auto group = it.get_group();
+            int group_id_i = group.get_group_id(0);
+            int group_id_j = group.get_group_id(1);
+
+            int row0 = group_id_i * Br;
+            int col0 = group_id_j * Bc;
+
+            if (row0 >= (int) N || col0 >= (int) N) {
+                return;
+            }
+
+            const float * Q_block = Q_d + (ptrdiff_t) row0 * q_row_stride;
+            const float * K_block = K_d + (ptrdiff_t) col0 * k_row_stride;
+            const float * V_block = V_d + (ptrdiff_t) col0 * v_row_stride;
+            float * O_block = dst_d + (ptrdiff_t) row0 * o_row_stride;
+
+            // these copies do not support non-contiguous tensors
+            ggml_sycl_memcpy<Br * DQK>(q_loc, Q_block);
+            ggml_sycl_memcpy<Bc * DQK>(k_loc, K_block);
+            ggml_sycl_memcpy<Bc * DV>(v_loc, V_block);
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            flash_attn_mul_mat_QK_kernel<DQK>(
+                it,
+                Q_block, q_row_stride,
+                K_block, k_row_stride,
+                s_loc, (ptrdiff_t) Bc,
+                Br, Bc);
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            flash_attn_softmax_kernel(
+                it,
+                s_loc, p_loc,
+                m_loc, l_loc,
+                Br, Bc,
+                l_d, m_d);
+
+            it.barrier(sycl::access::fence_space::local_space);
+
+            flash_attn_mul_mat_PV_kernel<DV>(
+                it,
+                p_loc, (ptrdiff_t) Bc,
+                V_block, v_row_stride,
+                O_block, o_row_stride,
+                Br, Bc);
+
+            it.barrier(sycl::access::fence_space::local_space);
+        });
+    });
+
+    stream->submit([&](sycl::handler & cgh) {
+        const ptrdiff_t o_stride = o_row_stride;
+
+        cgh.parallel_for(sycl::range<1>(N), [=](sycl::id<1> id_row) {
+            int row = id_row[0];
+            float l_val = l_d[row];
+
+            if (l_val <= 0.0f) {
+                return;
+            }
+
+            float inv_l = 1.0f / l_val;
+            float * o_row = dst_d + (ptrdiff_t) row * o_stride;
+
+            for (int col = 0; col < DV; ++col) {
+                o_row[col] *= inv_l;
+            }
+        });
+    });
+
+    // the kernels run asynchronously; wait before releasing the stats buffers
+    stream->wait();
+
+    sycl::free(l_d, *stream);
+    sycl::free(m_d, *stream);
+}
+
+void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * V = dst->src[2];
+
+    switch (Q->ne[0]) {
+        case 64:
+            GGML_ASSERT(V->ne[0] == 64);
+            ggml_sycl_op_flash_attn_2< 64,  64>(ctx, dst);
+            break;
+        case 80:
+            GGML_ASSERT(V->ne[0] == 80);
+            ggml_sycl_op_flash_attn_2< 80,  80>(ctx, dst);
+            break;
+        case 96:
+            GGML_ASSERT(V->ne[0] == 96);
+            ggml_sycl_op_flash_attn_2< 96,  96>(ctx, dst);
+            break;
+        case 112:
+            GGML_ASSERT(V->ne[0] == 112);
+            ggml_sycl_op_flash_attn_2<112, 112>(ctx, dst);
+            break;
+        case 128:
+            GGML_ASSERT(V->ne[0] == 128);
+            ggml_sycl_op_flash_attn_2<128, 128>(ctx, dst);
+            break;
+        case 256:
+            GGML_ASSERT(V->ne[0] == 256);
+            ggml_sycl_op_flash_attn_2<256, 256>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
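For validating the device path, a naive CPU oracle of what the kernels above compute is useful. The sketch below is my addition, not part of the patch, and `attn_reference` is a hypothetical name. It mirrors the currently supported case (F32, single head, square N x N attention, contiguous rows) and, like the patch, applies neither the usual 1/sqrt(d) scale nor a mask:

```cpp
// Naive single-head attention O = softmax(Q K^T) V on contiguous row-major
// buffers; usable as a CPU oracle when testing the SYCL path.
#include <cmath>
#include <vector>

void attn_reference(const float * Q, const float * K, const float * V,
                    float * O, int N, int D) {
    std::vector<float> s(N); // one row of scores at a time
    for (int i = 0; i < N; ++i) {
        // scores and row max
        float m = -INFINITY;
        for (int j = 0; j < N; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < D; ++k) {
                dot += Q[i * D + k] * K[j * D + k];
            }
            s[j] = dot;
            m = std::fmax(m, dot);
        }
        // stable softmax
        float l = 0.0f;
        for (int j = 0; j < N; ++j) {
            s[j] = std::exp(s[j] - m);
            l += s[j];
        }
        // weighted sum over V
        for (int d = 0; d < D; ++d) {
            float acc = 0.0f;
            for (int j = 0; j < N; ++j) {
                acc += s[j] * V[j * D + d];
            }
            O[i * D + d] = acc / l;
        }
    }
}
```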
diff --git a/ggml/src/ggml-sycl/fattn.hpp b/ggml/src/ggml-sycl/fattn.hpp
new file mode 100644
index 000000000..94188e18e
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn.hpp
@@ -0,0 +1,13 @@
+#ifndef GGML_SYCL_FATTN_HPP
+#define GGML_SYCL_FATTN_HPP
+
+#include "common.hpp"
+
+// Flash attention operation for the SYCL backend
+// This implements the Flash Attention algorithm optimized for SYCL devices
+void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+// Check whether flash attention is supported for the given tensor's inputs
+bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
+
+#endif // GGML_SYCL_FATTN_HPP
diff --git a/ggml/src/ggml-sycl/fattn_common.hpp b/ggml/src/ggml-sycl/fattn_common.hpp
new file mode 100644
index 000000000..53a99f44f
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn_common.hpp
@@ -0,0 +1,12 @@
+#ifndef GGML_SYCL_FATTN_COMMON_HPP
+#define GGML_SYCL_FATTN_COMMON_HPP
+
+// copy N contiguous floats; N is a compile-time constant so the loop can be
+// fully unrolled
+template <int N>
+inline void ggml_sycl_memcpy(float * dst, const float * src) {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+        dst[i] = src[i];
+    }
+}
+
+#endif // GGML_SYCL_FATTN_COMMON_HPP
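Note that, as called from fattn.cpp above, every work-item in a group executes the same full-tile `ggml_sycl_memcpy`, so each tile is copied Br times over. A cooperative variant would split the copy across the group; the sketch below is hypothetical (the name `ggml_sycl_memcpy_coop` and its signature are mine, not in the patch):

```cpp
// Hypothetical cooperative tile copy: work-items stride through the N
// elements by local_size, so the group copies the tile exactly once.
template <int N>
inline void ggml_sycl_memcpy_coop(float * dst, const float * src,
                                  int local_id, int local_size) {
    for (int i = local_id; i < N; i += local_size) {
        dst[i] = src[i];
    }
}
```

Either way, a group barrier is required after the copy, which fattn.cpp already issues.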
diff --git a/ggml/src/ggml-sycl/fattn_kernel.hpp b/ggml/src/ggml-sycl/fattn_kernel.hpp
new file mode 100644
index 000000000..822947056
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn_kernel.hpp
@@ -0,0 +1,141 @@
+#ifndef GGML_SYCL_FATTN_KERNEL_HPP
+#define GGML_SYCL_FATTN_KERNEL_HPP
+
+#include <sycl/sycl.hpp>
+
+template <int QKD>
+inline void flash_attn_mul_mat_QK_kernel(
+        sycl::nd_item<2> it,
+        const float * Q, ptrdiff_t q_row_stride,
+        const float * K, ptrdiff_t k_row_stride,
+        float * S, ptrdiff_t s_row_stride,
+        const int Br, const int Bc) {
+
+    const int i = it.get_local_id(0);
+    if (i >= Br) {
+        return;
+    }
+
+    const float * q_vec = Q + i * q_row_stride;
+    float * s_row = S + i * s_row_stride;
+
+    for (int j = 0; j < Bc; ++j) {
+        const float * k_vec = K + j * k_row_stride;
+        float score = 0.0f;
+
+#pragma unroll
+        for (int k = 0; k < QKD; ++k) {
+            score += q_vec[k] * k_vec[k];
+        }
+
+        s_row[j] = score;
+    }
+}
+
+inline void flash_attn_softmax_kernel(
+        sycl::nd_item<2> it,
+        float * S, float * P,
+        float * m_local, float * l_local,
+        const int Br, const int Bc,
+        float * l_d, float * m_d) {
+
+    const int li = it.get_local_id(0);
+    const int gi = it.get_group(0);
+    const int row = gi * Br + li;
+
+    if (li >= Br) {
+        return;
+    }
+
+    const int row_offset = li * Bc;
+
+    // 1. Load the running global stats for this row
+    float m_old = m_d[row];
+    float l_old = l_d[row];
+
+    // 2. Block max
+    float m_block = -INFINITY;
+    for (int j = 0; j < Bc; ++j) {
+        const float s_ij = S[row_offset + j];
+        m_block = sycl::fmax(m_block, s_ij);
+    }
+
+    // 3. Block exp-sum
+    float l_block = 0.0f;
+    for (int j = 0; j < Bc; ++j) {
+        const float e = sycl::exp(S[row_offset + j] - m_block);
+        P[row_offset + j] = e; // temporary store
+        l_block += e;
+    }
+
+    // 4. Merge block stats with the global stats (streaming softmax)
+    float m_new;
+    float l_new;
+
+    if (l_old == 0.0f && m_old == -INFINITY) {
+        // first block for this row
+        m_new = m_block;
+        l_new = l_block;
+    } else {
+        m_new = sycl::fmax(m_old, m_block);
+
+        const float alpha = sycl::exp(m_old - m_new);
+        const float beta  = sycl::exp(m_block - m_new);
+
+        l_new = alpha * l_old + beta * l_block;
+    }
+
+    // 5. Store the updated global stats
+    m_d[row] = m_new;
+    l_d[row] = l_new;
+
+    // 6. Convert the local e_ij to global probabilities p_ij
+    float scale_block = 0.0f;
+    if (l_new > 0.0f) {
+        scale_block = sycl::exp(m_block - m_new) / l_new;
+    }
+
+    for (int j = 0; j < Bc; ++j) {
+        P[row_offset + j] *= scale_block;
+    }
+
+    // 7. Optional: keep local copies
+    m_local[li] = m_new;
+    l_local[li] = l_new;
+}
+
+template <int VD>
+inline void flash_attn_mul_mat_PV_kernel(
+        sycl::nd_item<2> it,
+        const float * P, ptrdiff_t p_row_stride,
+        const float * V, ptrdiff_t v_row_stride,
+        float * O, ptrdiff_t o_row_stride,
+        const int Br, const int Bc) {
+
+    const int i = it.get_local_id(0);
+    if (i >= Br) {
+        return;
+    }
+
+    const float * p_row = P + i * p_row_stride;
+    float * o_row = O + i * o_row_stride;
+
+    for (int j = 0; j < VD; ++j) {
+        float acc = 0.0f;
+
+#pragma unroll
+        for (int k = 0; k < Bc; ++k) {
+            const float * v_row = V + k * v_row_stride;
+            acc += p_row[k] * v_row[j];
+        }
+
+        o_row[j] = acc;
+    }
+}
+
+#endif // GGML_SYCL_FATTN_KERNEL_HPP
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 33f903507..a37e89d4e 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3839,6 +3839,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_sycl_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_sycl_op_flash_attn(ctx, dst);
+            break;
         case GGML_OP_TIMESTEP_EMBEDDING:
             ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
@@ -4501,6 +4504,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_sycl_flash_attn_ext_supported(op);
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
             return true;
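For reference, the launch geometry in `ggml_sycl_op_flash_attn_2` assigns one work-group of Br work-items to each (query-tile, key/value-tile) pair. The standalone sketch below is illustrative only (not part of the patch) and prints that mapping; note that when N is not a multiple of the 32-wide tiles, the kernel loops above still iterate over full Bc columns, so ragged edge tiles are not yet masked:

```cpp
// Tile-to-work-group mapping for the nd_range used in fattn.cpp:
//   global = (Br * Tr, Tc), local = (Br, 1)
#include <algorithm>
#include <cstdio>

int main() {
    const int N = 100, Br = 32, Bc = 32; // N deliberately not a multiple of 32
    const int Tr = (N + Br - 1) / Br;    // number of query-row tiles
    const int Tc = (N + Bc - 1) / Bc;    // number of key/value-column tiles
    std::printf("groups: %d x %d, work-items per group: %d\n", Tr, Tc, Br);
    for (int i = 0; i < Tr; ++i) {
        for (int j = 0; j < Tc; ++j) {
            // rows/cols the kernels touch vs. what is actually in range
            std::printf("group (%d,%d): rows [%d,%d) cols [%d,%d), valid up to %d/%d\n",
                        i, j, i * Br, (i + 1) * Br, j * Bc, (j + 1) * Bc,
                        std::min((i + 1) * Br, N), std::min((j + 1) * Bc, N));
        }
    }
    return 0;
}
```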