@@ -93,122 +93,4 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
9393 }
9494
9595 return true ;
96- }
97-
98-
99-
100-
101- // #include "flash-attn-sycl.h"
102-
103- // #include "kernels/flash-attn-kernel.h"
104-
105- // #include <cmath>
106- // #include <cstring>
107- // #include <limits>
108- // #include <sycl/sycl.hpp>
109-
110- // #ifndef GGML_USE_SYCL
111- // #warning "SYCL not enabled. This source file will be ignored."
112- // #else
113-
114- // #define FLASH_ATTN_BR_MAX 32
115- // #define FLASH_ATTN_BC_MAX 32
116-
117- // // RAII helper to free device memory automatically
118- // class SyclDeviceBuffer {
119- // public:
120- // SyclDeviceBuffer(sycl::queue & q, size_t count)
121- // : queue(q), ptr(nullptr), size(count) {
122- // ptr = sycl::malloc_device<float>(count, queue);
123- // }
124-
125- // ~SyclDeviceBuffer() {
126- // if (ptr) {
127- // sycl::free(ptr, queue);
128- // }
129- // }
130-
131- // float * get() const { return ptr; }
132- // bool valid() const { return ptr != nullptr; }
133-
134- // private:
135- // sycl::queue & queue;
136- // float * ptr;
137- // size_t size;
138- // };
139-
140- // // Flash Attention: https://arxiv.org/abs/2205.14135
141- // void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
142- // const ggml_tensor * Q = dst->src[0];
143- // const ggml_tensor * K = dst->src[1];
144- // const ggml_tensor * V = dst->src[2];
145- // const ggml_tensor * mask = dst->src[3];
146-
147- // GGML_ASSERT(Q && K && V && dst);
148-
149- // if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
150- // fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
151- // return;
152- // }
153-
154- // const float * q_data = static_cast<const float *>(Q->data);
155- // const float * k_data = static_cast<const float *>(K->data);
156- // const float * v_data = static_cast<const float *>(V->data);
157- // float * dst_data = static_cast<float *>(dst->data);
158-
159- // sycl::queue & stream = *ctx.stream();
160-
161- // const int64_t d = Q->ne[0];
162- // const int64_t N = Q->ne[1];
163-
164- // float scale, max_bias, logit_softcap;
165- // std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
166- // std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
167- // std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
168-
169- // const bool masked = (mask != nullptr);
170-
171- // const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
172- // const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
173-
174- // const int Tr = (N + Br - 1) / Br;
175- // const int Tc = (N + Bc - 1) / Bc;
176-
177- // SyclDeviceBuffer l_buf(stream, N);
178- // SyclDeviceBuffer m_buf(stream, N);
179-
180- // if (!l_buf.valid() || !m_buf.valid()) {
181- // fprintf(stderr, "[SYCL] FLASH-ATTENTION: failed to allocate device buffers.\n");
182- // return;
183- // }
184-
185- // stream.fill(l_buf.get(), 0.0f, N).wait();
186- // stream.fill(m_buf.get(), -std::numeric_limits<float>::infinity(), N).wait();
187- // stream.fill(dst_data, 0.0f, ggml_nelements(dst)).wait();
188-
189- // for (int j = 0; j < Tc; ++j) {
190- // stream.submit([&](sycl::handler & cgh) {
191- // cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
192- // const int i = idx[0];
193- // flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(
194- // q_data, k_data, v_data, dst_data, l_buf.get(), m_buf.get(),
195- // i, j, Br, Bc, N, d, masked, scale);
196- // });
197- // });
198- // }
199- // stream.wait();
200- // }
201-
202- // bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
203- // const ggml_tensor * Q = dst->src[0];
204- // const ggml_tensor * K = dst->src[1];
205- // const ggml_tensor * V = dst->src[2];
206-
207- // if (!Q || !K || !V) return false;
208- // if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) return false;
209- // if (dst->type != GGML_TYPE_F32) return false;
210-
211- // return true;
212- // }
213-
214- // #endif
96+ }
0 commit comments