
Commit 6c517f1

feat: add sinks tensor support in fa impl
1 parent aed9b4f commit 6c517f1

2 files changed: 32 additions & 13 deletions


ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp

Lines changed: 31 additions & 12 deletions
@@ -19,6 +19,7 @@ void flash_attn_impl(hexagon::tensor * out,
                      const hexagon::tensor * k,
                      const hexagon::tensor * v,
                      const hexagon::tensor * mask,
+                     const hexagon::tensor * sinks,
                      hexagon::compute_params * params) {
     static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");

@@ -92,11 +93,12 @@ void flash_attn_impl(hexagon::tensor * out,
     }

     DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), flash_attn);
-    const uint8_t * q_ptr    = q->get_read_buffer();
-    const uint8_t * k_ptr    = k->get_read_buffer();
-    const uint8_t * v_ptr    = v->get_read_buffer();
-    const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
-    float *         VKQ32    = reinterpret_cast<float *>(cache_ptr);  // FP32 VKQ accumulator
+    const uint8_t * q_ptr     = q->get_read_buffer();
+    const uint8_t * k_ptr     = k->get_read_buffer();
+    const uint8_t * v_ptr     = v->get_read_buffer();
+    const uint8_t * mask_ptr  = mask ? mask->get_read_buffer() : nullptr;
+    const uint8_t * sinks_ptr = sinks ? sinks->get_read_buffer() : nullptr;
+    float *         VKQ32     = reinterpret_cast<float *>(cache_ptr);  // FP32 VKQ accumulator
     auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv);  // (temporary) FP16 VKQ accumulator
     auto * Q_q   = reinterpret_cast<npu_device_fp16_t *>(
         VKQ32 + 2 * aligned_dv);  // (temporary) buffer for Q converted to quantized/FP16
@@ -224,6 +226,22 @@ void flash_attn_impl(hexagon::tensor * out,
             }
         }

+        if (sinks_ptr) {
+            const float s = reinterpret_cast<const float *>(sinks_ptr)[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                hexagon::vec_scale_f32(VKQ32, ms, VKQ32, DV);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S * ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f / S;
         hexagon::vec_scale_f32(VKQ32, S_inv, VKQ32, DV);
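
For readers new to attention sinks: the added block folds one per-head sink logit (sinks_ptr[h]) into the online-softmax state the kernel carries across KV blocks, where M is the running maximum logit, S is the running sum of exp(logit - M), and VKQ32 is the V-weighted accumulator. The sink adds probability mass to the softmax denominator but contributes no value rows. A minimal scalar sketch of the same update (illustrative C++, not the Hexagon kernel; fold_sink and its parameter names are hypothetical):

// Minimal scalar sketch of the sink correction above (illustrative,
// not repo code). M is the running max logit, S the running softmax
// denominator, acc the V-weighted accumulator of length dv.
#include <cmath>
#include <cstddef>

void fold_sink(float s, float M, float & S, float * acc, size_t dv) {
    float ms = 1.0f;  // rescale factor for the state accumulated so far
    float vs = 1.0f;  // exp-weight contributed by the sink itself

    if (s > M) {
        // the sink logit becomes the new maximum: shrink the old state
        ms = expf(M - s);
        for (size_t i = 0; i < dv; ++i) {
            acc[i] *= ms;
        }
    } else {
        // otherwise weight the sink relative to the existing maximum
        vs = expf(s - M);
    }

    // the sink has no value rows, so only the denominator grows
    S = S * ms + vs;
}

Because only S and the accumulator change, the V /= S normalization that follows in the kernel needs no modification.
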
@@ -253,20 +271,21 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
         return false;
     }

-    const auto * q    = out->get_src(0);
-    const auto * k    = out->get_src(1);
-    const auto * v    = out->get_src(2);
-    const auto * mask = out->get_src(3);
-    if (!q || !k || !v || !mask) {
+    const auto * q = out->get_src(0);
+    const auto * k = out->get_src(1);
+    const auto * v = out->get_src(2);
+    if (!q || !k || !v) {
         DEVICE_LOG_DEBUG(
             "invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
         return false;
     }

+    const auto * mask  = out->get_src(3);
+    const auto * sinks = out->get_src(4);
     if (k->get_type() == NPU_DATA_TYPE_F16) {
-        flash_attn_impl<true>(out, q, k, v, mask, params);
+        flash_attn_impl<true>(out, q, k, v, mask, sinks, params);
     } else {
-        flash_attn_impl<false>(out, q, k, v, mask, params);
+        flash_attn_impl<false>(out, q, k, v, mask, sinks, params);
     }
     return true;
 }
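
One design note on the dispatcher change: q, k, and v remain mandatory, while mask and sinks are fetched after the validity check and may legitimately be null (the impl guards both with null checks). A hypothetical summary of the source-slot layout this implies (the enum and its names are illustrative, not from the repo, which indexes get_src() directly):

// Illustrative slot map for the flash-attention op's sources.
enum flash_attn_src_slot : int {
    FA_SRC_Q     = 0,  // query: required
    FA_SRC_K     = 1,  // key: required
    FA_SRC_V     = 2,  // value: required
    FA_SRC_MASK  = 3,  // attention mask: optional, may be null
    FA_SRC_SINKS = 4,  // per-head sink logits: optional, may be null
};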

ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 #include "remote.idl"

 const uint32_t DEVICE_TENSOR_MAX_DIMS = 4;
-const uint32_t DEVICE_TENSOR_MAX_SRC = 4;
+const uint32_t DEVICE_TENSOR_MAX_SRC = 5;
 const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 16;
 const uint32_t QUANT_BLOCK_SIZE = 32;
 const uint32_t QUANT_K_BLOCK_SIZE = 256;
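
DEVICE_TENSOR_MAX_SRC caps how many source-tensor slots an op can carry across the host/NPU boundary, so the bump from 4 to 5 is what makes room for the sinks tensor at src(4). A sketch of the kind of fixed-size array such a constant typically bounds (the struct below is an assumption for illustration, not the repo's actual IDL type):

// Hypothetical shape only: a per-op descriptor whose source array is
// sized by DEVICE_TENSOR_MAX_SRC. With the old value of 4 there was no
// slot to transport a fifth source such as the sinks tensor.
#include <cstdint>

constexpr uint32_t DEVICE_TENSOR_MAX_SRC = 5;

struct op_descriptor_sketch {
    uint64_t src_handles[DEVICE_TENSOR_MAX_SRC];  // q, k, v, mask, sinks
    uint32_t src_count;                           // number of valid slots
};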
