
Commit e6a5f7b

feat: perf opt add set rows (#59)
* Add power management utilities to NPU device context and update DCVS settings
* Update DCVS settings in power_utils to use v3 API and enhance power management
* wip
* Enhance dequantization functions by adding load_dequant_table support and updating signatures for improved performance
* use lut
* wip
* fix test failure
* wip
* Refactor load_qual_block_generic to improve block handling and optimize vector operations
* Enhance load_dual_block_generic and load_qual_block_generic to accept a mask parameter for improved block handling
* Refactor flash_attn_impl to optimize mask l2 prefetch
* wip
* wip
* wip
* wip
* add log
* link against shared libraries instead of static ones
* fix swiglu
* wip
* refactor expf_fix to handle overflow for different data types
* enhance is_glu_op_supported to validate shapes for multiple sources
* wip
* refactor logging macros to use hexagon namespace and improve formatting
* fix printf format error
* wip
* refactor: update static_assert messages for block size validation and add HVX_VectorPred_x3 type alias
* rename
* feat: enhance fa with mask
* wip
* wip
* refactor: replace instances of Q6_V_vzero() with kZeroV for consistency
* wip
* wip
* wip
* fix: improve address alignment check in HVX_Vector handling
* refactor: streamline vector dot product implementations for improved readability
* refactor: q4k add hvx intrinsic impl
* refactor: enhance dequantize_row_q4_K for clarity and performance
* refactor: optimize scale mask usage in dequantization functions for improved performance
* refactor: optimize dequantize_row_q4_K for intrinsic usage and performance improvements
* refactor: move GLU operation implementation into separated file
* sync after swiglu
* wip
* wip
* wip
* feat: increase prc main thread stack size
* fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant
* wip
* feat: add optimized vector operations for exponential and division with overflow handling
* wip
* feat: refactor exponential function to handle overflow and underflow with improved logic
* wip
* wip
* feat: add vector loading and scaling functions for improved performance in block processing
* wip
* feat: optimize block loading by refactoring scale index handling for improved performance
* use Q6_Vb_vlut32_VbVbR_nomatch instead
* feat: enhance scale loading by adding static assertion and restructuring block handling
* wip
* feat: refactor vec_dot_product_mixed_impl for improved clarity and performance
* wip
* feat: simplify vector loading functions and improve alignment handling
* wip
* feat: enhance scale loading mask with quantization block size validation
* wip
* feat: implement make_scale_load_mask function and refactor vector handling in vec_ops
* feat: enhance load_dual_block_generic to include scale indices for improved vector loading
* revert q8 dequant
* wip
* feat: optimize dequantization functions by removing unnecessary masking and updating lookup methods
* wip
* wip
* add qurt_mutex
* Add DMA transfer class and integrate into thread pool
* Enhance DMA transfer functionality by adding support for multiple descriptors and initiating transfers in parallel
* fix dma crash
* fix failed unit tests
* wip
* use alignas
* Improve DMA transfer error handling and update descriptor completion check
* Fix VTCM cache size calculation in element-wise operations
* Add cache clean operations before DMA transfers in element-wise operations
* reduce cache clean operations
* Refactor DMA transfer functions to support 1D operations and rename for clarity
* Enhance DMA transfer functionality by adding 2D submission support and improving descriptor initialization
* Update read buffer method to support forced invalidation and remove unnecessary invalidation calls in element-wise operations
* wip
* Improve DMA transfer handling in mul_mat_gemv_impl by replacing memcpy with initiate_dma_row_transfer and adding wait_for_dma logic
* fix 2d dma
* feat: add DMA plane cache
* rename
* wip
* use memcpy for debug
* fix cache plane calc
* refactor: remove debug logging from mul_mat_impl and optimize cache handling
* rename
* fix 2d dma type
* refactor: enhance DMA transfer handling in mul_mat_gemv_impl and wait functions
* refactor: optimize DMA transfer handling in mul_mat_gemv_impl and wait functions
* wip
* wip
* move op impl into sub dir
* add log
* fix: correct pointer usage in mul_mat_gemv_impl for next plane access
* fix: improve DMA transfer error handling in mul_mat_impl and mul_mat_gemv_impl
* fix: fix crash by using the entire row bytes
* wip
* wip
* fix: prevent parallelization for scalar src1 in is_mul_mat_supported
* fix: add dimension checks for 2D DMA transfers and fallback to 1D if necessary
* wip
* fix: enable thread barrier for mul multiplication operations
* feat: add synchronization checks for tensor operations and update related functions
* wip
* fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations
* Revert "fix: remove invalidation flag from get_read_buffer calls in element-wise and matrix multiplication operations" (this reverts commit af3441e)
* wip
* wip
* add comment
* fix: improve DMA transfer handling in mul_mat_gemv_impl for quantized source tensors
* add log
* try fix mulmat gemv
* wip
* fix: enhance DMA transfer handling in mul_mat_gemv_impl for quantized source tensors
* fix: optimize cache offset calculation and remove redundant swap in mul_mat_gemv_impl
* fix: refactor DMA transfer handling in mul_mat_gemv_impl for improved clarity and maintainability
* wip
* wip
* wip
* fix: enhance mul_mat_impl for improved cache handling and clarity
* fix: refactor tensor unflattening and DMA transfer initialization for improved clarity and type safety
* fix: improve cache handling of quant
* wip
* fix: improve cache handling in mul_mat_impl and mul_mat_gemv_impl for better memory efficiency
* rename
* add load_hexa_block_generic
* wip
* extract dequant block into separated function
* refactor: enhance dequantization functions with table parameter
* fix load_dual_block_generic
* refactor: rename dequantization functions for clarity and enhance block handling
* refactor: simplify dequantization logic by consolidating block handling and removing unused parameters
* wip
* wip
* feat: add make_qs_load_mask function and update load_dual_block_generic to use qs_indices
* fix load_dual_block_generic
* refactor: update load functions to use qs_indices for improved block loading
* wip
* fix: update loop indices and boundary checks to use size_t for better efficiency
* wip
* update make_scale_load_mask, to make it available for q8
* feat: add vec_dot_product_quant_impl for quantized dot product computation
* refactoring: move some quant func to dedicated file
* refactor: rename dequantization functions for clarity and consistency
* wip
* feat: enhance vec_dot_product_quant_impl with dual dequantization and improved assertions
* add vec_dot_product_vqf32_q40_f32
* wip
* wip
* wip
* wip
* implement vec_mpy_qf32_qf32_qf32 function and update vec_dot_product_vqf32_q40_f32 to use it
* wip
* add src0_plane_write_cache_offset
* wip
* enhance mul_mat_f32 to handle NPU_DATA_TYPE_Q4_0 for quantized matrix multiplication
* wip
* wip
* update test func
* refactor mul_mat_gemv_quant_impl to use get_nb for row stride and remove unused test function in init_f16_f32_table
* wip
* Add support for 4-block dequantization in vec_quant and update dot product implementation
* Refactor vec_dot_product_quant_impl to improve variable handling and enhance readability
* Refactor vec_dot_product_quant_impl to replace template function with inline vector operations
* use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32
* Revert "use Q6_Vqf32_vmpy_VsfVsf instead of Q6_Vqf32_vmpy_Vqf32Vqf32" (this reverts commit 5483916)
* wip
* improve log print in graph
* Refactor batched_row_dot to accept additional arguments and remove batched_row_dot_with_table
* Refactor synchronization functions to include previous operation and NE type parameters
* Refactor synchronization checks in several operations
* Update synchronization checks to include NPU_OP_COUNT in required conditions
* Add performance tracking to buffer management functions
* add memset
* add log
* fix: update backend device type from ACCEL to IGPU
* fix comment
* add get/set rows
* feat: implement row operation support checks in is_rows_supported
* feat: add support for I64 data type in rows operations
* feat: implement set_rows functionality for I32 and I64 data types
* wip
* fix set_rows
* feat: extend is_rows_supported to allow F32 data type in destination
* wip
* feat: rename set_rows function, add generic to its name
* disable q4_k
* move ops to separated file
* rename: op_impl -> op_registry
* refactor: update get_data_type struct to include output type for unary operations
* refactor: simplify vec_trans_impl by removing parameterized overload and using variadic templates
* add vec_trans_with_half_ret_impl
* add NPU_OP_CPY
* refactor: enhance is_unary_op_supported to handle non-continuous rows and add type support logging
* refactor: update vec_trans_with_half_ret_impl to use processed_bytes for clarity and accuracy
* wip
* refactor: optimize dequantize_vec_q40_qf32_4blocks by improving shuffling logic and reducing redundancy
* refactor: improve performance of vec_dot_product and dequantize functions by optimizing shuffling logic
* wip
* add dequantize_vec_q40_qf32_6blocks
* feat: add load_dequant_vec_q40_qf32_6blocks function for 6-block dequantization
* feat: enhance vec_dot_product_quant_impl with 6-element processing loop for improved performance
* Revert "feat: enhance vec_dot_product_quant_impl with 6-element processing loop for improved performance" (this reverts commit a5c8fa3, since there's a performance degradation)
* fix: correct load_hexa_block_generic return type and update dequantization logic
* wip
* wip
* feat: add make_q40_qs_load_mask function and update vec_dot_product_vqf32_q40_f32
* fix dequant load
* add debug log
* wip
* wip
* fix shuffle index array
* refactor: simplify load mask generation and improve index shuffling for q4 blocks
* wip
* wip
* fix comment
* wip
* update ops.md
* update ops.md by create_ops_docs.py

# Conflicts:
#   docs/ops.md
1 parent 38ae191 commit e6a5f7b

File tree

19 files changed: +18425 −499 lines


docs/ops.md

Lines changed: 102 additions & 102 deletions
Large diffs are not rendered by default.

docs/ops/hexagon-npu.csv

Lines changed: 17663 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-qnn/npu/device/device.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 
 #include "graph.hpp"
 #include "hexagon_npu.h"
-#include "op_impl.hpp"
+#include "op_registry.hpp"
 #include "remote.h"
 #include "tensor.hpp"
 #include "thread_pool.hpp"

ggml/src/ggml-qnn/npu/device/graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 
 #include "graph.hpp"
 
-#include "op_impl.hpp"
+#include "op_registry.hpp"
 #include "util.hpp"
 #include "vtcm_mem.hpp"
 

ggml/src/ggml-qnn/npu/device/op/op_impl.cpp renamed to ggml/src/ggml-qnn/npu/device/op/op_eltwise.hpp

Lines changed: 56 additions & 159 deletions
@@ -1,18 +1,10 @@
+#pragma once
 
-
-#include "op_impl.hpp"
-
-#include "op_flash_attn.hpp"
-#include "op_glu.hpp"
-#include "op_mul_mat.hpp"
-#include "op_rope.hpp"
+#include "op_types.hpp"
 #include "type_traits.hpp"
 #include "vec_ops.hpp"
 
-#include <cmath>
-#include <type_traits>
-
-namespace {
+namespace hexagon {
 
 template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
 inline void vec_op_f32_f32(const float * src0, const float * src1, float * dst, size_t count) {
@@ -41,6 +33,14 @@ inline void vec_op_f16_f16(const npu_device_fp16_t * src0,
     vec_trans_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, dst, count);
 }
 
+template <HVX_Vector (*_OpUnaryTransform)(HVX_VectorPair)>
+inline void unary_vec_op_f16_f32(const float * src, npu_device_fp16_t * dst, size_t count, size_t) {
+    // TODO: remove the unused param
+
+    using namespace hexagon::vec;
+    vec_trans_with_half_ret_impl<_OpUnaryTransform, float, npu_device_fp16_t>(src, dst, count);
+}
+
 inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
     // TODO: fix this since qf16 has less precision than fp16
     return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_VhfVhf(a, b));
@@ -55,16 +55,25 @@ inline HVX_Vector vmul_f16_f16(HVX_Vector a, HVX_Vector b) {
     return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
 }
 
+inline HVX_Vector vequals_f16_f32(HVX_VectorPair a) {
+    const HVX_Vector kZeroV = Q6_V_vzero();
+    HVX_Vector lo = Q6_Vqf32_vadd_Vqf32Vsf(kZeroV, Q6_V_lo_W(a));
+    HVX_Vector hi = Q6_Vqf32_vadd_Vqf32Vsf(kZeroV, Q6_V_hi_W(a));
+    a = Q6_W_vcombine_VV(hi, lo);
+    return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(a));
+}
+
 template <typename T> struct get_data_type {};
 
 template <typename _TyData> struct get_data_type<void (*)(const _TyData *, const _TyData *, _TyData *, size_t)> {
     using type = _TyData;
 };
 
-template <typename _TyData, typename _TyParam>
-struct get_data_type<void (*)(const _TyData *, _TyData *, size_t, _TyParam)> {
-    using type = _TyData;
-    using param_type = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
+template <typename _TyInput, typename _TyOutput, typename _TyParam>
+struct get_data_type<void (*)(const _TyInput *, _TyOutput *, size_t, _TyParam)> {
+    using type        = _TyInput;
+    using output_type = _TyOutput;
+    using param_type  = typename std::remove_cv<typename std::remove_reference<_TyParam>::type>::type;
 };
 
 template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::compute_params * params) {
@@ -280,8 +289,9 @@ void rms_norm_vec_f32(const float * src, float * dst, size_t count, float eps) {
 
 // TODO: merge with element_wise_op?
 template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_params * params) {
-    using data_type  = typename get_data_type<decltype(_RowFunc)>::type;
-    using param_type = typename get_data_type<decltype(_RowFunc)>::param_type;
+    using input_type  = typename get_data_type<decltype(_RowFunc)>::type;
+    using output_type = typename get_data_type<decltype(_RowFunc)>::output_type;
+    using param_type  = typename get_data_type<decltype(_RowFunc)>::param_type;
 
     if (!out) {
         return false;
@@ -311,7 +321,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
     DEVICE_SCOPED_OP_PERFORMANCE_TRACKER(out, params->get_thread_index());
 
     const auto param = out->get_op_param<param_type>(0);
-    const size_t valid_row_bytes = src0->get_ne(0) * sizeof(data_type);
+    const size_t valid_row_bytes = src0->get_ne(0) * sizeof(input_type);
     for (int64_t ir = start_end.first; ir < start_end.second; ++ir) {
         const auto i03 = ir / rows_per_cube;
         const auto i02 = ir / out->get_ne(1) - i03 * out->get_ne(2);
@@ -323,7 +333,7 @@ template <auto _RowFunc> bool unary_op(hexagon::tensor * out, hexagon::compute_p
             hexagon::l2fetch_row(src0_row + src0->get_nb(1), valid_row_bytes);
         }
 
-        _RowFunc(reinterpret_cast<const data_type *>(src0_row), reinterpret_cast<data_type *>(dst_row),
+        _RowFunc(reinterpret_cast<const input_type *>(src0_row), reinterpret_cast<output_type *>(dst_row),
                  static_cast<size_t>(out->get_ne(0)), param);
     }
 
@@ -336,7 +346,7 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
                            const npu_device_tensor_spec * srcs,
                            size_t src_len) {
     const auto op = op_spec->op;
-    if (op != NPU_OP_RMS_NORM) {
+    if (op != NPU_OP_RMS_NORM && op != NPU_OP_CPY) {
         DEVICE_LOG_DEBUG("[%s]unsupported\n", hexagon::op_get_name(op));
         return false;
     }
@@ -347,21 +357,36 @@ bool is_unary_op_supported(const npu_device_tensor_op_spec * op_spec,
    }
 
     const auto & src0 = srcs[0];
-    if (dst->type != src0.type) {
-        DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
-                         hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
+    if (!hexagon::is_same_shape(src0, *dst)) {
+        DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
         return false;
     }
 
-    if (dst->type != NPU_DATA_TYPE_F32) {
-        DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
-                         hexagon::get_type_name(dst->type));
-        return false;
-    }
+    if (op == NPU_OP_RMS_NORM) {
+        if (dst->type != src0.type) {
+            DEVICE_LOG_DEBUG("[%s]src0.type and dst.type mismatch: %s vs %s\n", hexagon::op_get_name(op),
+                             hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
+            return false;
+        }
 
-    if (!hexagon::is_same_shape(src0, *dst)) {
-        DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
-        return false;
+        if (dst->type != NPU_DATA_TYPE_F32) {
+            DEVICE_LOG_DEBUG("[%s]unsupported data type: %s\n", hexagon::op_get_name(op),
+                             hexagon::get_type_name(dst->type));
+            return false;
+        }
+    } else {
+        if (dst->nb[1] < dst->nb[0] || src0.nb[1] < src0.nb[0]) {
+            // TODO: support non-continuous row
+            DEVICE_LOG_DEBUG("[%s]unsupported non-continuous row\n", hexagon::op_get_name(op));
+            return false;
+        }
+
+        if (dst->type != NPU_DATA_TYPE_F16 || src0.type != NPU_DATA_TYPE_F32) {
+            // TODO: support more types
+            DEVICE_LOG_DEBUG("[%s]unsupported data type src:%s dst:%s\n", hexagon::op_get_name(op),
+                             hexagon::get_type_name(src0.type), hexagon::get_type_name(dst->type));
+            return false;
+        }
     }
 
     return true;
@@ -378,132 +403,4 @@ bool is_unary_op_required_sync(npu_device_tensor_op prev_op,
            prev_op != NPU_OP_COUNT;
 }
 
-struct op_capabilities {
-    npu_device_tensor_op                op;
-    hexagon::op_is_supported_func_type  is_supported;
-    hexagon::op_required_sync_func_type requires_thread_barrier_func;
-    hexagon::compute_func_type          compute_funcs[NPU_DATA_TYPE_COUNT];
-};
-
-constexpr const op_capabilities kOpCapabilities[] = {
-    {
-     NPU_OP_MUL_MAT, hexagon::is_mul_mat_supported,
-     hexagon::is_mul_mat_required_sync,
-     {
-         hexagon::mul_mat_f32, // NPU_DATA_TYPE_F32
-         nullptr,              // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_ADD, is_element_wise_op_supported,
-     is_element_wise_op_required_sync, {
-         element_wise_op<vec_op_f32_f32<vadd_f32_f32>>, // NPU_DATA_TYPE_F32
-         element_wise_op<vec_op_f16_f16<vadd_f16_f16>>, // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_SUB, is_element_wise_op_supported,
-     is_element_wise_op_required_sync, {
-         element_wise_op<vec_op_f32_f32<vsub_f32_f32>>, // NPU_DATA_TYPE_F32
-         element_wise_op<vec_op_f16_f16<vsub_f16_f16>>, // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_MUL, is_element_wise_op_supported,
-     is_element_wise_op_required_sync, {
-         element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
-         element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_RMS_NORM, is_unary_op_supported,
-     is_unary_op_required_sync, {
-         unary_op<rms_norm_vec_f32>, // NPU_DATA_TYPE_F32
-         nullptr,                    // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_FLASH_ATTN, hexagon::is_flash_attn_supported,
-     hexagon::is_flash_attn_required_sync,
-     {
-         hexagon::flash_attn_f32, // NPU_DATA_TYPE_F32
-         nullptr,                 // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_ROPE, hexagon::is_rope_supported,
-     hexagon::is_rope_required_sync,
-     {
-         hexagon::rope_f32, // NPU_DATA_TYPE_F32
-         nullptr,           // NPU_DATA_TYPE_F16
-     }, },
-    {
-     NPU_OP_GLU, hexagon::is_glu_op_supported,
-     hexagon::is_glu_required_sync,
-     {
-         hexagon::glu_f32, // NPU_DATA_TYPE_F32
-         hexagon::glu_f16, // NPU_DATA_TYPE_F16
-     }, },
-};
-
-static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
-              "kOpArray[NPU_OP_MUL_MAT] != mul_mat_f32");
-
-static_assert(std::size(kOpCapabilities) == NPU_OP_COUNT);
-static_assert(kOpCapabilities[NPU_OP_MUL_MAT].op == NPU_OP_MUL_MAT, "kOpArray[NPU_OP_MUL_MAT].op != NPU_OP_MUL_MAT");
-static_assert(kOpCapabilities[NPU_OP_MUL].op == NPU_OP_MUL, "kOpArray[NPU_OP_MUL].op != NPU_OP_MUL");
-static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
-              "kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
-static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
-              "kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
-static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
-static_assert(kOpCapabilities[NPU_OP_GLU].op == NPU_OP_GLU, "kOpArray[NPU_OP_GLU].op != NPU_OP_GLU");
-
-hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
-    if (op >= NPU_OP_COUNT) {
-        return nullptr;
-    }
-
-    return kOpCapabilities[op].compute_funcs[type];
-}
-
-}  // namespace
-
-namespace hexagon {
-
-compute_func_type get_compute_func(tensor * dst) {
-    return get_compute_func_impl(dst->get_op(), dst->get_type());
-}
-
-bool requires_thread_barrier(npu_device_tensor_op prev_op,
-                             const npu_device_ne_type & prev_ne,
-                             npu_device_tensor_op op,
-                             const npu_device_ne_type & ne) {
-    if (op >= NPU_OP_COUNT) {
-        return false;
-    }
-
-    auto requires_thread_barrier_func = kOpCapabilities[op].requires_thread_barrier_func;
-    return requires_thread_barrier_func && requires_thread_barrier_func(prev_op, prev_ne, op, ne);
-}
-
-bool support_op(const npu_device_tensor_op_spec * op_spec,
-                const npu_device_tensor_spec * dst,
-                const npu_device_tensor_spec * srcs,
-                size_t src_len) {
-    if (!op_spec) {
-        DEVICE_LOG_ERROR("[hexagon-npu]invalid op_spec\n");
-        return false;
-    }
-
-    const auto op = op_spec->op;
-    auto is_supported_func = kOpCapabilities[op].is_supported;
-    if (!is_supported_func || !is_supported_func(op_spec, dst, srcs, src_len)) {
-        DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
-        return false;
-    }
-
-    if (get_compute_func_impl(op, dst->type) == nullptr) {
-        DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
-                         get_type_name(dst->type));
-        return false;
-    }
-
-    return true;
-}
-
 } // namespace hexagon

ggml/src/ggml-qnn/npu/device/op/op_flash_attn.cpp

Lines changed: 3 additions & 3 deletions
@@ -58,10 +58,10 @@ void flash_attn_impl(hexagon::tensor * out,
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     const auto & k_type_traits = hexagon::get_type_traits(kKvDataType);
-    const auto q_to_vec_dot = k_type_traits.from_float;
+    const auto q_to_kv_type = k_type_traits.from_float;
     constexpr const auto kq_vec_dot = _IsKvF16 ? hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16> :
                                                  hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>;
-    if (!q_to_vec_dot) {
+    if (!q_to_kv_type) {
         DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
         return;
     }
@@ -134,7 +134,7 @@ void flash_attn_impl(hexagon::tensor * out,
                          (iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
                      nullptr;
 
-    q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
+    q_to_kv_type(reinterpret_cast<const float *>(q_data), Q_q, DK);
 
     if (kHasMask) {
         hexagon::l2fetch_row(reinterpret_cast<const uint8_t *>(mp), mask->get_nb(1));

ggml/src/ggml-qnn/npu/device/op/op_glu.cpp

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ inline void glu_vec_op_f32_f32(const float * src0,
                                size_t count,
                                hexagon::HVX_VectorPair_x4 coeff) {
     using namespace hexagon::vec;
-    vec_trans_with_param_impl<float, hexagon::HVX_VectorPair_x4, hexagon::vec_swiglu_f32_f32>(src0, src1, dst, count,
-                                                                                             coeff);
+    vec_trans_impl<hexagon::vec_swiglu_f32_f32, float, hexagon::HVX_VectorPair_x4>(src0, src1, dst, count, coeff);
 }
 
 template <auto _GluRowFunc, auto _CoeffLoadFunc>
