9 changes: 9 additions & 0 deletions ggml/src/ggml-sycl/CMakeLists.txt
@@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
file(GLOB GGML_SOURCES_SYCL "*.cpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})

# Include flash-attn sources (SYCL-optimized flash attention implementation)
file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})

# Also include kernel headers under flash-attn/kernels
file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})

if (WIN32)
# To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
98 changes: 98 additions & 0 deletions ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp
@@ -0,0 +1,98 @@
#include "flash-attn-sycl.h"

#include "kernels/flash-attn-kernel.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <limits>
#include <sycl/sycl.hpp>

#define FLASH_ATTN_BR_MAX 32
#define FLASH_ATTN_BC_MAX 32

// Flash Attention: https://arxiv.org/abs/2205.14135
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];

GGML_ASSERT(Q != nullptr);
GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);
GGML_ASSERT(dst != nullptr);

if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
return;
}

const float * Q_d = (const float *) Q->data;
const float * K_d = (const float *) K->data;
const float * V_d = (const float *) V->data;
float * dst_d = (float *) dst->data;

dpct::queue_ptr stream = ctx.stream();

const int64_t d = Q->ne[0];
const int64_t N = Q->ne[1];

float scale;
float max_bias;
float logit_softcap;

std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

const bool masked = (mask != nullptr);
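    // NOTE: max_bias (ALiBi) and logit_softcap are read from op_params but are not passed to
    // flash_attn_tiled_kernel; only scale and the causal flag are applied, and the mask
    // tensor's values are never read.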

const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);

const int Tr = (N + Br - 1) / Br;
const int Tc = (N + Bc - 1) / Bc;

float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);

stream->fill(l_d, 0.0f, N);
stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
stream->fill(dst_d, 0.0f, N * d);
stream->wait();

for (int j = 0; j < Tc; ++j) {
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
const int i = idx[0];
flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
Bc, N, d, masked, scale);
});
});
}

stream->wait();

sycl::free(l_d, *stream);
sycl::free(m_d, *stream);
}

bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

if (Q == nullptr || K == nullptr || V == nullptr) {
return false;
}

if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
return false;
}

if (dst->type != GGML_TYPE_F32) {
return false;
}

return true;
}
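
The host wrapper above drives the tiled kernel one key/value column block at a time, keeping the running softmax statistics in l_d and m_d. For small N and d the result can be checked against a plain attention computation. Below is a minimal host-side reference for such a comparison (a sketch, not part of this change; it assumes the same row-major [N x d] layout and causal-mask convention as the kernel, and attention_reference is a hypothetical helper name):

```cpp
#include <cmath>
#include <vector>

// Plain O(N^2 * d) attention: O = softmax(scale * Q * K^T, optionally causal) * V.
// Q, K, V and O are row-major [N x d]; "masked" applies the same causal rule as the kernel.
static void attention_reference(const float * Q, const float * K, const float * V, float * O,
                                int N, int d, bool masked, float scale) {
    std::vector<float> s(N);
    for (int i = 0; i < N; ++i) {
        // scores for query row i, tracking the row maximum for a stable softmax
        float m = -INFINITY;
        for (int j = 0; j < N; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) {
                dot += Q[i * d + k] * K[j * d + k];
            }
            s[j] = (masked && j > i) ? -INFINITY : dot * scale;
            m = std::fmax(m, s[j]);
        }
        // softmax weights and their sum
        float l = 0.0f;
        for (int j = 0; j < N; ++j) {
            s[j] = std::isinf(s[j]) ? 0.0f : std::exp(s[j] - m);
            l += s[j];
        }
        // weighted sum of V rows
        for (int k = 0; k < d; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < N; ++j) {
                acc += s[j] * V[j * d + k];
            }
            O[i * d + k] = acc / l;
        }
    }
}
```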
10 changes: 10 additions & 0 deletions ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h
@@ -0,0 +1,10 @@
#pragma once

#include "../common.hpp"

// Flash attention operation for SYCL backend
// This implements the Flash Attention algorithm optimized for SYCL devices
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

// Check whether flash attention is supported for the given tensor
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
108 changes: 108 additions & 0 deletions ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h
@@ -0,0 +1,108 @@
#pragma once

#include <cmath>
#include <sycl/sycl.hpp>

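// Processes one Br x Bc tile of the attention matrix for the given (i_block, j_block):
// computes S = scale * Q_tile * K_tile^T, the tile-local softmax statistics, and merges the
// partial result into O using the running (m, l) statistics kept in global memory.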
template <int Br_MAX = 32, int Bc_MAX = 32>
inline void flash_attn_tiled_kernel(const float * Q,
const float * K,
const float * V,
float * O,
float * l,
float * m,
const int i_block,
const int j_block,
const int Br,
const int Bc,
const int N,
const int d,
const bool masked,
const float scale) {
const int i_start = i_block * Br;
const int j_start = j_block * Bc;

float S[Br_MAX][Bc_MAX];
float P[Br_MAX][Bc_MAX];
float m_local[Br_MAX];
float l_local[Br_MAX];

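    // 1) tile scores: S[qi][kj] = scale * dot(Q[q_row], K[k_row]); out-of-range or causally
    //    masked entries are set to -inf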
for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}

for (int kj = 0; kj < Bc; ++kj) {
const int k_row = j_start + kj;
if (k_row >= N) {
S[qi][kj] = -INFINITY;
continue;
}

if (masked && k_row > q_row) {
S[qi][kj] = -INFINITY;
continue;
}

float score = 0.0f;
for (int k = 0; k < d; ++k) {
score += Q[q_row * d + k] * K[k_row * d + k];
}
S[qi][kj] = score * scale;
}
}

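    // 2) tile-local softmax: row maximum m_local and row sum l_local of exp(S - m_local)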
for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}

m_local[qi] = -INFINITY;
for (int kj = 0; kj < Bc; ++kj) {
if (j_start + kj < N) {
m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
}
}

l_local[qi] = 0.0f;
for (int kj = 0; kj < Bc; ++kj) {
if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
l_local[qi] += P[qi][kj];
} else {
P[qi][kj] = 0.0f;
}
}
}

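    // 3) online-softmax merge of this tile into the global O, m and l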
for (int qi = 0; qi < Br; ++qi) {
const int q_row = i_start + qi;
if (q_row >= N) {
continue;
}

const float m_old = m[q_row];
const float m_new = sycl::fmax(m_old, m_local[qi]);
const float l_old = l[q_row];
const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];

const float correction_old = sycl::exp(m_old - m_new);
const float correction_new = sycl::exp(m_local[qi] - m_new);

for (int k = 0; k < d; ++k) {
float pv = 0.0f;
for (int kj = 0; kj < Bc; ++kj) {
const int v_row = j_start + kj;
if (v_row < N) {
pv += P[qi][kj] * V[v_row * d + k];
}
}

const int o_idx = q_row * d + k;
            // FlashAttention merge: O is stored already normalized, so the previous value must
            // be rescaled by l_old * exp(m_old - m_new) before adding this tile's contribution.
            O[o_idx] = (l_old * correction_old * O[o_idx] + correction_new * pv) / l_new;
}

l[q_row] = l_new;
m[q_row] = m_new;
}
}
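
The merge in step 3 is the online-softmax update from the FlashAttention paper referenced in flash-attn-sycl.cpp (Algorithm 1). With $(m_{\mathrm{old}}, \ell_{\mathrm{old}})$ the running statistics stored in m and l, $(\tilde m, \tilde\ell)$ the tile-local values m_local and l_local, and $\tilde P V$ the per-row accumulator pv, the update is:

$$
m_{\mathrm{new}} = \max(m_{\mathrm{old}}, \tilde m), \qquad
\ell_{\mathrm{new}} = e^{m_{\mathrm{old}} - m_{\mathrm{new}}}\,\ell_{\mathrm{old}} + e^{\tilde m - m_{\mathrm{new}}}\,\tilde\ell, \qquad
O_{\mathrm{new}} = \frac{\ell_{\mathrm{old}}\, e^{m_{\mathrm{old}} - m_{\mathrm{new}}}\, O_{\mathrm{old}} + e^{\tilde m - m_{\mathrm{new}}}\, \tilde P V}{\ell_{\mathrm{new}}}
$$

correction_old and correction_new are the two exponential factors; the extra $\ell_{\mathrm{old}}$ in front of $O_{\mathrm{old}}$ is needed because O is kept normalized by the running sum after every column block.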
5 changes: 5 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3839,6 +3839,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_sycl_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_sycl_op_flash_attn(ctx, dst);
break;
case GGML_OP_TIMESTEP_EMBEDDING:
ggml_sycl_op_timestep_embedding(ctx, dst);
break;
@@ -4501,6 +4504,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
return ggml_sycl_flash_attn_ext_supported(op);
case GGML_OP_POOL_2D:
case GGML_OP_ACC:
return true;
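
To see how a graph reaches the two new case labels, the op is created with ggml_flash_attn_ext() and computed on the SYCL backend. A minimal end-to-end sketch (illustrative only, not part of this change; it assumes the upstream ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap) signature and the public ggml-backend / ggml-alloc APIs):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-sycl.h"

#include <cmath>
#include <vector>

int main() {
    const int64_t d = 64;   // head size
    const int64_t N = 128;  // sequence length

    ggml_backend_t backend = ggml_backend_sycl_init(0);  // first SYCL device

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 16 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,  // tensor data will live in the backend buffer
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d, N, 1);
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d, N, 1);
    ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d, N, 1);

    // scale = 1/sqrt(d), no ALiBi (max_bias = 0), no softcap, no mask tensor
    ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/nullptr,
                                            1.0f / std::sqrt((float) d), 0.0f, 0.0f);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<float> ones(d * N, 1.0f);
    ggml_backend_tensor_set(q, ones.data(), 0, ggml_nbytes(q));
    ggml_backend_tensor_set(k, ones.data(), 0, ggml_nbytes(k));
    ggml_backend_tensor_set(v, ones.data(), 0, ggml_nbytes(v));

    ggml_backend_graph_compute(backend, gf);  // routes to ggml_sycl_op_flash_attn

    std::vector<float> result(ggml_nelements(out));
    ggml_backend_tensor_get(out, result.data(), 0, ggml_nbytes(out));

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```

Passing a non-null mask instead makes this SYCL path apply a causal mask, as noted in flash-attn-sycl.cpp.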