
Commit 8e8fb57

safranowith and ye-NX committed
add include in ggml-sycl.cpp
Co-authored-by: safranowith <[email protected]>
Co-authored-by: ye-NX <[email protected]>
1 parent af5b644 commit 8e8fb57


3 files changed: +122 -2 lines changed


ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp

Lines changed: 118 additions & 0 deletions
@@ -94,3 +94,121 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
 
     return true;
 }
+
+
+
+
+// #include "flash-attn-sycl.h"
+
+// #include "kernels/flash-attn-kernel.h"
+
+// #include <cmath>
+// #include <cstring>
+// #include <limits>
+// #include <sycl/sycl.hpp>
+
+// #ifndef GGML_USE_SYCL
+// #warning "SYCL not enabled. This source file will be ignored."
+// #else
+
+// #define FLASH_ATTN_BR_MAX 32
+// #define FLASH_ATTN_BC_MAX 32
+
+// // RAII helper to free device memory automatically
+// class SyclDeviceBuffer {
+//   public:
+//     SyclDeviceBuffer(sycl::queue & q, size_t count)
+//         : queue(q), ptr(nullptr), size(count) {
+//         ptr = sycl::malloc_device<float>(count, queue);
+//     }
+
+//     ~SyclDeviceBuffer() {
+//         if (ptr) {
+//             sycl::free(ptr, queue);
+//         }
+//     }
+
+//     float * get() const { return ptr; }
+//     bool valid() const { return ptr != nullptr; }
+
+//   private:
+//     sycl::queue & queue;
+//     float * ptr;
+//     size_t size;
+// };
+
+// // Flash Attention: https://arxiv.org/abs/2205.14135
+// void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+//     const ggml_tensor * Q = dst->src[0];
+//     const ggml_tensor * K = dst->src[1];
+//     const ggml_tensor * V = dst->src[2];
+//     const ggml_tensor * mask = dst->src[3];
+
+//     GGML_ASSERT(Q && K && V && dst);
+
+//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
+//         return;
+//     }
+
+//     const float * q_data = static_cast<const float *>(Q->data);
+//     const float * k_data = static_cast<const float *>(K->data);
+//     const float * v_data = static_cast<const float *>(V->data);
+//     float * dst_data = static_cast<float *>(dst->data);
+
+//     sycl::queue & stream = *ctx.stream();
+
+//     const int64_t d = Q->ne[0];
+//     const int64_t N = Q->ne[1];
+
+//     float scale, max_bias, logit_softcap;
+//     std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+//     std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+//     std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+//     const bool masked = (mask != nullptr);
+
+//     const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
+//     const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
+
+//     const int Tr = (N + Br - 1) / Br;
+//     const int Tc = (N + Bc - 1) / Bc;
+
+//     SyclDeviceBuffer l_buf(stream, N);
+//     SyclDeviceBuffer m_buf(stream, N);
+
+//     if (!l_buf.valid() || !m_buf.valid()) {
+//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: failed to allocate device buffers.\n");
+//         return;
+//     }
+
+//     stream.fill(l_buf.get(), 0.0f, N).wait();
+//     stream.fill(m_buf.get(), -std::numeric_limits<float>::infinity(), N).wait();
+//     stream.fill(dst_data, 0.0f, ggml_nelements(dst)).wait();
+
+//     for (int j = 0; j < Tc; ++j) {
+//         stream.submit([&](sycl::handler & cgh) {
+//             cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
+//                 const int i = idx[0];
+//                 flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(
+//                     q_data, k_data, v_data, dst_data, l_buf.get(), m_buf.get(),
+//                     i, j, Br, Bc, N, d, masked, scale);
+//             });
+//         });
+//     }
+//     stream.wait();
+// }
+
+// bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
+//     const ggml_tensor * Q = dst->src[0];
+//     const ggml_tensor * K = dst->src[1];
+//     const ggml_tensor * V = dst->src[2];
+
+//     if (!Q || !K || !V) return false;
+//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) return false;
+//     if (dst->type != GGML_TYPE_F32) return false;
+
+//     return true;
+// }
+
+// #endif
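
For reference, the l (running row-sum) and m (running row-max) buffers in the draft above are the online-softmax state of the FlashAttention algorithm cited in the comment (https://arxiv.org/abs/2205.14135); filling them with 0 and -infinity, and dst with 0, matches the base case of the recurrence below. This is a sketch of the standard per-row update over key/value blocks j = 1..Tc, written with tau for the scale parameter; the tiled kernel itself is not part of this diff, so how exactly it realizes the update is an assumption:

\begin{aligned}
S_j        &= \tau \, Q K_j^{\top}, \\
m^{(j)}    &= \max\!\bigl(m^{(j-1)},\ \operatorname{rowmax}(S_j)\bigr), \\
\ell^{(j)} &= e^{m^{(j-1)} - m^{(j)}} \, \ell^{(j-1)} + \operatorname{rowsum}\!\bigl(e^{S_j - m^{(j)}}\bigr), \\
O^{(j)}    &= e^{m^{(j-1)} - m^{(j)}} \, O^{(j-1)} + e^{S_j - m^{(j)}} V_j,
\end{aligned}
\qquad m^{(0)} = -\infty,\quad \ell^{(0)} = 0,\quad O^{(0)} = 0,\quad O = O^{(T_c)} / \ell^{(T_c)}.

Note that max_bias and logit_softcap are read from op_params in the draft but are not yet forwarded to the kernel call.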

ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.h

Lines changed: 3 additions & 2 deletions
@@ -2,9 +2,10 @@
 
 #include "../common.hpp"
 
+
 // Flash attention operation for SYCL backend
 // This implements the Flash Attention algorithm optimized for SYCL devices
-void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_op_flash_attn( ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 // Check if flash attention is supported for given tensor
-bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst);
+bool ggml_sycl_flash_attn_ext_supported(const struct ggml_tensor * dst);
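
For context, here is a minimal call-site sketch for the two declarations above, assuming the usual pattern of gating the op behind its support check; try_flash_attn is an illustrative name, and this wrapper is not part of the commit or of the existing ggml-sycl dispatch code:

// Hypothetical wrapper, not part of this commit: route a node through the
// SYCL flash-attention path only when the support check passes.
#include "flash-attn/flash-attn-sycl.h"  // also pulls in ../common.hpp for ggml_backend_sycl_context

static bool try_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    if (!ggml_sycl_flash_attn_ext_supported(dst)) {
        return false;  // caller falls back to the existing attention implementation
    }
    ggml_sycl_op_flash_attn(ctx, dst);
    return true;       // node handled by the flash-attention op
}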

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@
 #include "ggml-sycl/element_wise.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "flash-attn/flash-attn-sycl.h"
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/set.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
