
Commit fc0e041

ye-NX and safranowith committed
sycl: initialize flash-attention implementation
Co-authored-by: safranowith <[email protected]>
Co-authored-by: ye-NX <[email protected]>
1 parent 6de8ed7 commit fc0e041

5 files changed: +230 -0 lines changed

ggml/src/ggml-sycl/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -27,6 +27,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
 file(GLOB GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
+# Include flash-attn sources (SYCL optimized flash attention implementation)
+file(GLOB GGML_HEADERS_SYCL_FLASH "flash-attn/*.h" "flash-attn/*.hpp")
+file(GLOB GGML_SOURCES_SYCL_FLASH "flash-attn/*.cpp" "flash-attn/*.c")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH} ${GGML_SOURCES_SYCL_FLASH})
+
+# Also include kernel headers under flash-attn/kernels
+file(GLOB GGML_HEADERS_SYCL_FLASH_KERNELS "flash-attn/kernels/*.h" "flash-attn/kernels/*.hpp")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL_FLASH_KERNELS})
+
 if (WIN32)
     # To generate a Visual Studio solution, using Intel C++ Compiler for ggml-sycl is mandatory
     if( ${CMAKE_GENERATOR} MATCHES "Visual Studio" AND NOT (${CMAKE_GENERATOR_TOOLSET} MATCHES "Intel C"))
ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
#include "flash-attn-sycl.h"

#include "kernels/flash-attn-kernel.h"

#include <cmath>
#include <cstring>
#include <limits>
#include <sycl/sycl.hpp>

#define FLASH_ATTN_BR_MAX 32
#define FLASH_ATTN_BC_MAX 32

// Flash Attention: https://arxiv.org/abs/2205.14135
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    GGML_ASSERT(Q != nullptr);
    GGML_ASSERT(K != nullptr);
    GGML_ASSERT(V != nullptr);
    GGML_ASSERT(dst != nullptr);

    if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
        fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
        return;
    }

    const float * Q_d   = (const float *) Q->data;
    const float * K_d   = (const float *) K->data;
    const float * V_d   = (const float *) V->data;
    float       * dst_d = (float *) dst->data;

    dpct::queue_ptr stream = ctx.stream();

    const int64_t d = Q->ne[0];
    const int64_t N = Q->ne[1];

    float scale;
    float max_bias;
    float logit_softcap;

    std::memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
    std::memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
    std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

    const bool masked = (mask != nullptr);

    const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
    const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);

    const int Tr = (N + Br - 1) / Br;
    const int Tc = (N + Bc - 1) / Bc;

    float * l_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);
    float * m_d = (float *) sycl::malloc_device(N * sizeof(float), *stream);

    stream->fill(l_d, 0.0f, N);
    stream->fill(m_d, -std::numeric_limits<float>::infinity(), N);
    stream->fill(dst_d, 0.0f, N * d);
    stream->wait();

    for (int j = 0; j < Tc; ++j) {
        stream->submit([&](sycl::handler & cgh) {
            cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
                const int i = idx[0];
                flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(Q_d, K_d, V_d, dst_d, l_d, m_d, i, j, Br,
                                                                              Bc, N, d, masked, scale);
            });
        });
    }

    stream->wait();

    sycl::free(l_d, *stream);
    sycl::free(m_d, *stream);
}

bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    if (Q == nullptr || K == nullptr || V == nullptr) {
        return false;
    }

    if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) {
        return false;
    }

    if (dst->type != GGML_TYPE_F32) {
        return false;
    }

    return true;
}
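
Side note on where the three op_params values read above come from: in ggml, a FLASH_ATTN_EXT node is normally built with the ggml_flash_attn_ext() helper, which stores Q/K/V/mask as dst->src[0..3] and packs scale, max_bias and logit_softcap into dst->op_params. A minimal host-side sketch follows; build_flash_attn_node() is a hypothetical helper, the tensors are simplified to the 2-D F32 shapes this initial kernel actually reads, and the exact ggml_flash_attn_ext() signature should be checked against the ggml.h in the tree this is built against.

// Illustrative sketch (not part of the diff above): how a FLASH_ATTN_EXT node gets its op_params.
#include <math.h>

#include "ggml.h"

static struct ggml_tensor * build_flash_attn_node(struct ggml_context * ctx, int64_t d, int64_t N) {
    // 2-D F32 tensors: ne[0] = head dim d, ne[1] = sequence length N,
    // matching what ggml_sycl_op_flash_attn() reads from Q->ne[0] and Q->ne[1]
    struct ggml_tensor * q = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d, N);
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d, N);
    struct ggml_tensor * v = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d, N);

    const float scale         = 1.0f / sqrtf((float) d); // read back as op_params[0]
    const float max_bias      = 0.0f;                     // op_params[1] (not used by the SYCL op yet)
    const float logit_softcap = 0.0f;                     // op_params[2] (not used by the SYCL op yet)

    // passing mask = NULL means the SYCL op runs without the causal mask
    return ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL, scale, max_bias, logit_softcap);
}
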
ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
#pragma once

#include "../common.hpp"

// Flash attention operation for SYCL backend
// This implements the Flash Attention algorithm optimized for SYCL devices
void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

// Check if flash attention is supported for given tensor
bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
ggml/src/ggml-sycl/flash-attn/kernels/flash-attn-kernel.h

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
#pragma once

#include <sycl/sycl.hpp>

template <int Br_MAX = 32, int Bc_MAX = 32>
inline void flash_attn_tiled_kernel(const float * Q,
                                    const float * K,
                                    const float * V,
                                    float * O,
                                    float * l,
                                    float * m,
                                    const int i_block,
                                    const int j_block,
                                    const int Br,
                                    const int Bc,
                                    const int N,
                                    const int d,
                                    const bool masked,
                                    const float scale) {
    const int i_start = i_block * Br;
    const int j_start = j_block * Bc;

    float S[Br_MAX][Bc_MAX];
    float P[Br_MAX][Bc_MAX];
    float m_local[Br_MAX];
    float l_local[Br_MAX];

    for (int qi = 0; qi < Br; ++qi) {
        const int q_row = i_start + qi;
        if (q_row >= N) {
            continue;
        }

        for (int kj = 0; kj < Bc; ++kj) {
            const int k_row = j_start + kj;
            if (k_row >= N) {
                S[qi][kj] = -INFINITY;
                continue;
            }

            if (masked && k_row > q_row) {
                S[qi][kj] = -INFINITY;
                continue;
            }

            float score = 0.0f;
            for (int k = 0; k < d; ++k) {
                score += Q[q_row * d + k] * K[k_row * d + k];
            }
            S[qi][kj] = score * scale;
        }
    }

    for (int qi = 0; qi < Br; ++qi) {
        const int q_row = i_start + qi;
        if (q_row >= N) {
            continue;
        }

        m_local[qi] = -INFINITY;
        for (int kj = 0; kj < Bc; ++kj) {
            if (j_start + kj < N) {
                m_local[qi] = sycl::fmax(m_local[qi], S[qi][kj]);
            }
        }

        l_local[qi] = 0.0f;
        for (int kj = 0; kj < Bc; ++kj) {
            if (j_start + kj < N && !sycl::isinf(S[qi][kj])) {
                P[qi][kj] = sycl::exp(S[qi][kj] - m_local[qi]);
                l_local[qi] += P[qi][kj];
            } else {
                P[qi][kj] = 0.0f;
            }
        }
    }

    for (int qi = 0; qi < Br; ++qi) {
        const int q_row = i_start + qi;
        if (q_row >= N) {
            continue;
        }

        const float m_old = m[q_row];
        const float m_new = sycl::fmax(m_old, m_local[qi]);
        const float l_old = l[q_row];
        const float l_new = sycl::exp(m_old - m_new) * l_old + sycl::exp(m_local[qi] - m_new) * l_local[qi];

        const float correction_old = sycl::exp(m_old - m_new);
        const float correction_new = sycl::exp(m_local[qi] - m_new);

        for (int k = 0; k < d; ++k) {
            float pv = 0.0f;
            for (int kj = 0; kj < Bc; ++kj) {
                const int v_row = j_start + kj;
                if (v_row < N) {
                    pv += P[qi][kj] * V[v_row * d + k];
                }
            }

            const int o_idx = q_row * d + k;
            O[o_idx] = (l_old * correction_old * O[o_idx] + correction_new * pv) / l_new;
        }

        l[q_row] = l_new;
        m[q_row] = m_new;
    }
}
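
For reference, the last loop in the kernel above is the FlashAttention online-softmax update from the paper linked in flash-attn-sycl.cpp (arXiv 2205.14135), applied once per KV tile to every query row. With m_old/l_old the running row maximum and softmax denominator, m_local/l_local the current tile's local statistics, pv the tile's P*V accumulation, and O stored already normalized by the running denominator, the per-row update is:

    m_new = max(m_old, m_local)
    l_new = exp(m_old - m_new) * l_old + exp(m_local - m_new) * l_local
    O_new = (l_old * exp(m_old - m_new) * O_old + exp(m_local - m_new) * pv) / l_new

correction_old and correction_new in the code are the two exponential rescaling factors, and l_new/m_new are written back to the per-row l[] and m[] buffers so the next tile can continue the recurrence.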

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 5 additions & 0 deletions
@@ -3839,6 +3839,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_sycl_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_sycl_op_flash_attn(ctx, dst);
+            break;
         case GGML_OP_TIMESTEP_EMBEDDING:
             ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
@@ -4501,6 +4504,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_sycl_flash_attn_ext_supported(op);
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
             return true;
