@@ -14,15 +14,20 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
 
 // From: ggml/src/ggml-cpu/ops.cpp
 template <bool _IsKvF16>
-void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
-                     const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
+void flash_attn_impl(hexagon::tensor * out,
+                     const hexagon::tensor * q,
+                     const hexagon::tensor * k,
+                     const hexagon::tensor * v,
+                     const hexagon::tensor * mask,
+                     hexagon::compute_params * params) {
     static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
 
     constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
 
     if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
         DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
-                         hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
+                         hexagon::get_type_name(k->get_type()),
+                         hexagon::get_type_name(v->get_type()));
         return;
     }
 
@@ -80,7 +85,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
     const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
     uint8_t * dst_ptr = out->get_write_buffer();
     if (!dst_ptr) {
-        DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
+        DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
+                         (void *) out,
                          hexagon::get_type_name(out->get_type()));
         return;
     }
@@ -118,7 +124,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
 
         const npu_device_fp16_t * mp =
             mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
-                                                                   (iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
+                                                                   (iq2 % mask->get_ne(2)) * mask->get_nb(2) +
+                                                                   (iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
                        nullptr;
 
         // k indices
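The hunk above changes how the mask row pointer is derived: instead of folding only `iq3` into the mask's dim 2, both `iq2` and `iq3` are wrapped by the mask's own extents, so a smaller mask can be broadcast across those dimensions. Below is a minimal standalone sketch of that addressing; `tensor_view` and `mask_row` are hypothetical stand-ins for `hexagon::tensor` and the in-loop expression, not part of the actual source.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for hexagon::tensor: only the fields the addressing needs.
struct tensor_view {
    const uint8_t * data;   // base pointer of the mask buffer
    int64_t         ne[4];  // element counts per dimension
    size_t          nb[4];  // byte strides per dimension
};

// Mask-row addressing as in the new code: iq2 and iq3 are reduced modulo the
// mask's extents in dims 2 and 3, so the mask is broadcast over q's head/batch dims.
inline const uint8_t * mask_row(const tensor_view & mask, int64_t iq1, int64_t iq2, int64_t iq3) {
    return mask.data + iq1 * mask.nb[1] + (iq2 % mask.ne[2]) * mask.nb[2] + (iq3 % mask.ne[3]) * mask.nb[3];
}
```

The caller would then reinterpret the returned pointer as `const npu_device_fp16_t *`, as the diff does.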
@@ -251,8 +258,8 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
     const auto * v = out->get_src(2);
     const auto * mask = out->get_src(3);
     if (!q || !k || !v || !mask) {
-        DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v,
-                         (void *) mask);
+        DEVICE_LOG_DEBUG(
+            "invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
         return false;
     }
 
@@ -264,8 +271,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
     return true;
 }
 
-bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
-                             const npu_device_tensor_spec * srcs, size_t src_len) {
+bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
+                             const npu_device_tensor_spec * dst,
+                             const npu_device_tensor_spec * srcs,
+                             size_t src_len) {
+    const auto op = op_spec->op;
     if (op != NPU_OP_FLASH_ATTN) {
         DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op);
         return false;
@@ -295,7 +305,9 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
 
     const auto * v = &srcs[2];
     if (v->type != k->type) { // TODO: support more v types
-        DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type),
+        DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n",
+                         op_get_name(op),
+                         get_type_name(v->type),
                          get_type_name(k->type));
         return false;
     }
@@ -310,28 +322,42 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
         DEVICE_LOG_DEBUG(
             "[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, "
             "v ne: %ld, %ld, %ld, %ld\n",
-            op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3],
-            v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
+            op_get_name(op),
+            dst->ne[0],
+            dst->ne[1],
+            dst->ne[2],
+            dst->ne[3],
+            q->ne[0],
+            q->ne[1],
+            q->ne[2],
+            q->ne[3],
+            v->ne[0],
+            v->ne[1],
+            v->ne[2],
+            v->ne[3]);
         return false;
     }
 
     if (is_transposed_or_permuted(dst->nb)) {
-        DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op),
-                         dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+        DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n",
+                         op_get_name(op),
+                         dst->nb[0],
+                         dst->nb[1],
+                         dst->nb[2],
+                         dst->nb[3]);
         return false;
     }
 
     if (q->ne[0] != k->ne[0]) {
         DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
-                         op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
-                         k->ne[3]);
-        return false;
-    }
-
-    if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
-        // TODO: add broadcast support
-        DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
-                         op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
+                         op_get_name(op),
+                         q->ne[0],
+                         q->ne[1],
+                         q->ne[2],
+                         q->ne[3],
+                         k->ne[0],
+                         k->ne[1],
+                         k->ne[2],
                          k->ne[3]);
         return false;
     }