Skip to content

Commit 4a3a874

Browse files
committed
disable broadcast on flash_attn_ext
1 parent 560729e commit 4a3a874

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,14 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
312312
return false;
313313
}
314314

315+
if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
316+
// TODO: add broadcast support; until then q and k must agree on ne[2] and ne[3], and q->ne[3] must be 1
317+
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
318+
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
319+
k->ne[3]);
320+
return false;
321+
}
322+
315323
return true;
316324
}
317325

0 commit comments

Comments
 (0)