
Commit dcd7ca5

safranowith and ye-NX committed
remove unrelated changes
Co-authored-by: safranowith <[email protected]>
Co-authored-by: ye-NX <[email protected]>
1 parent 8e8fb57 commit dcd7ca5

1 file changed: 1 addition, 119 deletions

ggml/src/ggml-sycl/flash-attn/flash-attn-sycl.cpp

Lines changed: 1 addition & 119 deletions
@@ -93,122 +93,4 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
     }
 
     return true;
-}
-
-
-
-
-// #include "flash-attn-sycl.h"
-
-// #include "kernels/flash-attn-kernel.h"
-
-// #include <cmath>
-// #include <cstring>
-// #include <limits>
-// #include <sycl/sycl.hpp>
-
-// #ifndef GGML_USE_SYCL
-// #warning "SYCL not enabled. This source file will be ignored."
-// #else
-
-// #define FLASH_ATTN_BR_MAX 32
-// #define FLASH_ATTN_BC_MAX 32
-
-// // RAII helper to free device memory automatically
-// class SyclDeviceBuffer {
-//   public:
-//     SyclDeviceBuffer(sycl::queue & q, size_t count)
-//         : queue(q), ptr(nullptr), size(count) {
-//         ptr = sycl::malloc_device<float>(count, queue);
-//     }
-
-//     ~SyclDeviceBuffer() {
-//         if (ptr) {
-//             sycl::free(ptr, queue);
-//         }
-//     }
-
-//     float * get() const { return ptr; }
-//     bool valid() const { return ptr != nullptr; }
-
-//   private:
-//     sycl::queue & queue;
-//     float * ptr;
-//     size_t size;
-// };
-
-// // Flash Attention: https://arxiv.org/abs/2205.14135
-// void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-//     const ggml_tensor * Q = dst->src[0];
-//     const ggml_tensor * K = dst->src[1];
-//     const ggml_tensor * V = dst->src[2];
-//     const ggml_tensor * mask = dst->src[3];
-
-//     GGML_ASSERT(Q && K && V && dst);
-
-//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
-//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: tensor type not supported (Q=%d, K=%d, V=%d, dst=%d)\n", Q->type, K->type, V->type, dst->type);
-//         return;
-//     }
-
-//     const float * q_data = static_cast<const float *>(Q->data);
-//     const float * k_data = static_cast<const float *>(K->data);
-//     const float * v_data = static_cast<const float *>(V->data);
-//     float * dst_data = static_cast<float *>(dst->data);
-
-//     sycl::queue & stream = *ctx.stream();
-
-//     const int64_t d = Q->ne[0];
-//     const int64_t N = Q->ne[1];
-
-//     float scale, max_bias, logit_softcap;
-//     std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
-//     std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
-//     std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
-
-//     const bool masked = (mask != nullptr);
-
-//     const int Br = std::min((int) FLASH_ATTN_BR_MAX, (int) N);
-//     const int Bc = std::min((int) FLASH_ATTN_BC_MAX, (int) N);
-
-//     const int Tr = (N + Br - 1) / Br;
-//     const int Tc = (N + Bc - 1) / Bc;
-
-//     SyclDeviceBuffer l_buf(stream, N);
-//     SyclDeviceBuffer m_buf(stream, N);
-
-//     if (!l_buf.valid() || !m_buf.valid()) {
-//         fprintf(stderr, "[SYCL] FLASH-ATTENTION: failed to allocate device buffers.\n");
-//         return;
-//     }
-
-//     stream.fill(l_buf.get(), 0.0f, N).wait();
-//     stream.fill(m_buf.get(), -std::numeric_limits<float>::infinity(), N).wait();
-//     stream.fill(dst_data, 0.0f, ggml_nelements(dst)).wait();
-
-//     for (int j = 0; j < Tc; ++j) {
-//         stream.submit([&](sycl::handler & cgh) {
-//             cgh.parallel_for(sycl::range<1>(Tr), [=](sycl::id<1> idx) {
-//                 const int i = idx[0];
-//                 flash_attn_tiled_kernel<FLASH_ATTN_BR_MAX, FLASH_ATTN_BC_MAX>(
-//                     q_data, k_data, v_data, dst_data, l_buf.get(), m_buf.get(),
-//                     i, j, Br, Bc, N, d, masked, scale);
-//             });
-//         });
-//     }
-//     stream.wait();
-// }
-
-// bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
-//     const ggml_tensor * Q = dst->src[0];
-//     const ggml_tensor * K = dst->src[1];
-//     const ggml_tensor * V = dst->src[2];
-
-//     if (!Q || !K || !V) return false;
-//     if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32) return false;
-//     if (dst->type != GGML_TYPE_F32) return false;
-
-//     return true;
-// }
-
-// #endif
+}
