@@ -94,3 +94,121 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
9494
9595 return true ;
9696}
97+
98+
99+
100+
101+ // #include "flash-attn-sycl.h"
102+
103+ // #include "kernels/flash-attn-kernel.h"
104+
105+ // #include <cmath>
106+ // #include <cstring>
107+ // #include <limits>
108+ // #include <sycl/sycl.hpp>
109+
110+ // #ifndef GGML_USE_SYCL
111+ // #warning "SYCL not enabled. This source file will be ignored."
112+ // #else
113+
114+ // #define FLASH_ATTN_BR_MAX 32
115+ // #define FLASH_ATTN_BC_MAX 32
116+
117+ // // RAII helper to free device memory automatically
118+ // class SyclDeviceBuffer {
119+ // public:
120+ // SyclDeviceBuffer(sycl::queue & q, size_t count)
121+ // : queue(q), ptr(nullptr), size(count) {
122+ // ptr = sycl::malloc_device<float>(count, queue);
123+ // }
124+
125+ // ~SyclDeviceBuffer() {
126+ // if (ptr) {
127+ // sycl::free(ptr, queue);
128+ // }
129+ // }
130+
131+ // float * get() const { return ptr; }
132+ // bool valid() const { return ptr != nullptr; }
133+
134+ // private:
135+ // sycl::queue & queue;
136+ // float * ptr;
137+ // size_t size;
138+ // };
139+
140+ // // Flash Attention: https://arxiv.org/abs/2205.14135
141+ // void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
142+ // const ggml_tensor * Q = dst->src[0];
143+ // const ggml_tensor * K = dst->src[1];
144+ // const ggml_tensor * V = dst->src[2];
145+ // const ggml_tensor * mask = dst->src[3];
146+
147+ // GGML_ASSERT(Q && K && V && dst);
148+
149+ // if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
150+ // fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
151+ // return;
152+ // }
153+
154+ // const float * q_data = static_cast<const float *>(Q->data);
155+ // const float * k_data = static_cast<const float *>(K->data);
156+ // const float * v_data = static_cast<const float *>(V->data);
157+ // float * dst_data = static_cast<float *>(dst->data);
158+
159+ // sycl::queue & stream = *ctx.stream();
160+
161+ // const int64_t d = Q->ne[0];
162+ // const int64_t N = Q->ne[1];
163+
164+ // float scale, max_bias, logit_softcap;
165+ // std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
166+ // std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
167+ // std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
168+
169+ // const bool masked = (mask != nullptr);
170+
171+ // const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
172+ // const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
173+
174+ // const int Tr = (N + Br - 1) / Br;
175+ // const int Tc = (N + Bc - 1) / Bc;
176+
177+ // SyclDeviceBuffer l_buf(stream, N);
178+ // SyclDeviceBuffer m_buf(stream, N);
179+
180+ // if (!l_buf.valid() || !m_buf.valid()) {
181+ // fprintf(stderr, "[SYCL] FLASH-ATTENTION: failed to allocate device buffers.\n");
182+ // return;
183+ // }
184+
185+ // stream.fill(l_buf.get(), 0.0f, N).wait();
186+ // stream.fill(m_buf.get(), -std::numeric_limits<float>::infinity(), N).wait();
187+ // stream.fill(dst_data, 0.0f, ggml_nelements(dst)).wait();
188+
189+ // for (int j = 0; j < Tc; ++j) {
190+ // stream.submit([&](sycl::handler & cgh) {
191+ // cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
192+ // const int i = idx[0];
193+ // flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(
194+ // q_data, k_data, v_data, dst_data, l_buf.get(), m_buf.get(),
195+ // i, j, Br, Bc, N, d, masked, scale);
196+ // });
197+ // });
198+ // }
199+ // stream.wait();
200+ // }
201+
202+ // bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
203+ // const ggml_tensor * Q = dst->src[0];
204+ // const ggml_tensor * K = dst->src[1];
205+ // const ggml_tensor * V = dst->src[2];
206+
207+ // if (!Q || !K || !V) return false;
208+ // if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) return false;
209+ // if (dst->type != GGML_TYPE_F32) return false;
210+
211+ // return true;
212+ // }
213+
214+ // #endif
0 commit comments