@@ -19,6 +19,7 @@ void flash_attn_impl(hexagon::tensor * out,
                      const hexagon::tensor * k,
                      const hexagon::tensor * v,
                      const hexagon::tensor * mask,
+                     const hexagon::tensor * sinks,
                      hexagon::compute_params * params) {
     static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
 
@@ -92,11 +93,12 @@ void flash_attn_impl(hexagon::tensor * out,
     }
 
     DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), flash_attn);
-    const uint8_t * q_ptr    = q->get_read_buffer();
-    const uint8_t * k_ptr    = k->get_read_buffer();
-    const uint8_t * v_ptr    = v->get_read_buffer();
-    const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
-    float *         VKQ32    = reinterpret_cast<float *>(cache_ptr);  // FP32 VKQ accumulator
+    const uint8_t * q_ptr     = q->get_read_buffer();
+    const uint8_t * k_ptr     = k->get_read_buffer();
+    const uint8_t * v_ptr     = v->get_read_buffer();
+    const uint8_t * mask_ptr  = mask ? mask->get_read_buffer() : nullptr;
+    const uint8_t * sinks_ptr = sinks ? sinks->get_read_buffer() : nullptr;
+    float *         VKQ32     = reinterpret_cast<float *>(cache_ptr);  // FP32 VKQ accumulator
     auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv);  // (temporary) FP16 VKQ accumulator
     auto * Q_q   = reinterpret_cast<npu_device_fp16_t *>(
         VKQ32 + 2 * aligned_dv);  // (temporary) buffer for Q converted to quantized/FP16
@@ -224,6 +226,22 @@ void flash_attn_impl(hexagon::tensor * out,
             }
         }
 
+        if (sinks_ptr) {
+            const float s = reinterpret_cast<const float *>(sinks_ptr)[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                hexagon::vec_scale_f32(VKQ32, ms, VKQ32, DV);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S * ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f / S;
         hexagon::vec_scale_f32(VKQ32, S_inv, VKQ32, DV);
@@ -253,20 +271,21 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
         return false;
     }
 
-    const auto * q    = out->get_src(0);
-    const auto * k    = out->get_src(1);
-    const auto * v    = out->get_src(2);
-    const auto * mask = out->get_src(3);
-    if (!q || !k || !v || !mask) {
+    const auto * q = out->get_src(0);
+    const auto * k = out->get_src(1);
+    const auto * v = out->get_src(2);
+    if (!q || !k || !v) {
         DEVICE_LOG_DEBUG(
             "invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
         return false;
     }
 
+    const auto * mask  = out->get_src(3);
+    const auto * sinks = out->get_src(4);
     if (k->get_type() == NPU_DATA_TYPE_F16) {
-        flash_attn_impl<true>(out, q, k, v, mask, params);
+        flash_attn_impl<true>(out, q, k, v, mask, sinks, params);
     } else {
-        flash_attn_impl<false>(out, q, k, v, mask, params);
+        flash_attn_impl<false>(out, q, k, v, mask, sinks, params);
     }
     return true;
 }
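
For context on the second hunk: the optional sinks tensor (src[4]) supplies one logit per head, and the patch folds it into the kernel's online softmax after the KV loop. The sink acts as one extra attention score that contributes exp(s) to the softmax denominator S but no value row to the VKQ32 accumulator, with the usual rescale by expf(M - s) when s exceeds the running maximum M. Below is a minimal standalone sketch of the same update in plain C++; softmax_state and apply_sink are illustrative names, not part of the patch.

#include <cmath>
#include <vector>

// Running online-softmax state for one output row (illustrative layout).
struct softmax_state {
    float              M;  // running maximum of the attention logits
    float              S;  // running sum of exp(logit - M)
    std::vector<float> V;  // accumulated exp(logit - M) * value rows
};

// Fold a per-head sink logit `s` into the state: it adds exp(s) to the
// denominator but contributes no value vector to the numerator.
inline void apply_sink(softmax_state & st, float s) {
    if (s > st.M) {
        const float ms = std::exp(st.M - s);  // rescale factor for the old state
        for (float & x : st.V) {
            x *= ms;
        }
        st.S = st.S * ms + 1.0f;  // the sink itself contributes exp(s - s) == 1
        st.M = s;
    } else {
        st.S += std::exp(s - st.M);
    }
    // The final output row is st.V[i] / st.S, matching the "V /= S" step above.
}

In the kernel itself, M, S and VKQ32 are the per-row running max, sum and value accumulator, DV is the value head size, and sinks appears to hold one FP32 logit per head, indexed by h.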