Commit a29243e
feat: perf opt quant (#47)
* feat: add mixed precision dot product implementation and function declaration
* feat: implement mixed precision vector dot product and conversion functions
* fix: update data type handling in matrix multiplication implementation
* fix: adjust row count handling in matrix multiplication implementation for accurate slicing
* fix: optimize matrix multiplication implementation by unrolling the loop
* update performance tracking for matrix multiplication implementation
* add fetching
* wip
* fix: support F16 * F32 multiplication in is_mul_mat_supported function
* fix: improve src0 fetching logic in vec_dot_product_mixed_impl for better alignment handling
* fix test failure for row width 67
* try to fix the failing test
* fix: rename aligned_address to align_down for clarity in vector alignment handling (see the sketch after this list)
* wip
* qnn fix: update device capabilities for quantized types in qnn-lib to improve compatibility
* fix test failure at width == 193
* fix: replace zero vector initialization with previous vector in mixed dot product implementation
* wip
* fix: improve handling of last vector in mixed dot product implementation
* wip
* wip
* wip
* wip
* Enhance mul_mat_f32 function to support quantized types and improve static assertions
* rename
* Refactor dequantization functions to use npu_device_fp16_t and improve type handling
* Optimize dequantization in dequantize_row_q8_0 by replacing qf32 multiplication with qf16
* Optimize dequantization in dequantize_row_q4_0 by replacing qf32 multiplication with qf16
* Add hvx_vsf_convert_vhf function for improved vector conversion
* add perf logs
* Refactor dequantize_row_q4_0 for alignment
* Update logging in supports_op_impl and supports_op to use ggml_op_desc for better clarity
* Add support for ROPE operation in NPU capabilities and related functions
* Implement ROPE operation in tensor and op_rope, including cache initialization and correction dimension calculations
* enable ROPE by adding operation validation
* add support for the case where freq is null
* wip
* Refactor rope_f32 to improve indexing by introducing total_planes calculation
* reformat
* Refactor rope_f32 to optimize data access patterns by introducing row and plane pointers
* Add performance tracking to rope_f32 function for enhanced profiling
* Refactor rope_f32 to use a templated implementation
* Refactor rope_impl to replace loop with memcpy for improved performance
* Refactor mul_mat_impl to support quantization as a template parameter
* wip
* wip
* Refactor rope_impl to optimize plane indexing in the processing loop
* Add aligned vector dot product implementation for mixed precision types
* wip
* Enhance matrix multiplication for F32 and F16 types with alignment checks
* Optimize vec_dot_product_mix_aligned_impl for improved performance with additional vector sums
* Add alignment checks for matrix multiplication and vector dot products
* Refactor matrix multiplication to use function pointers for improved readability and maintainability
* Fix alignment check in is_dot_product_aligned to ensure correct vector size handling
* Remove unused f16_to_f32_table parameter from quantization and dequantization functions
* wip
* Add L2 fetch for src1 plane rows in matrix multiplication implementation
* wip
* Refactor hvx_vsf_convert_vhf to accept an additional parameter for flexibility in vector multiplication
* Refactor vec_dot_product_mix_aligned_impl to improve variable naming for clarity
* Refactor load_dual_block_generic and dequantize_row_q4_0 to improve performance
* Refactor vector operation functions to improve clarity and consistency in variable usage
* wip
* wip
* Refactor dequantize_row_q4_0_impl for improved clarity and performance in vector operations
* wip
* Update load_dual_block_generic to use intrinsics
* Refactor load_dual_block_generic and load_qual_block_generic for improved performance and clarity
* wip
* wip
* Optimize dequantize_row_q8_0 for improved performance by unrolling the for loop
* wip
* wip
* fix typo
1 parent 989772c commit a29243e
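
Several bullets above (mixed-precision dot products, the aligned_address -> align_down rename, the alignment checks) hinge on rounding pointers down to an HVX vector boundary. Below is a minimal sketch of what such helpers typically look like, assuming a 128-byte HVX vector; kBytesPerVector and unaligned_bytes do appear in the diffs further down, but the exact signatures here are assumptions, not the committed code.

// Minimal sketch, assuming a 128-byte HVX vector; signatures are assumed.
#include <cstddef>
#include <cstdint>

namespace hexagon {

constexpr size_t kBytesPerVector = 128;  // HVX vector width (assumed value)

// Round a pointer down to the start of the HVX vector that contains it.
template <typename T> inline const T * align_down(const T * ptr) {
    const auto addr = reinterpret_cast<uintptr_t>(ptr);
    return reinterpret_cast<const T *>(addr & ~(kBytesPerVector - 1));
}

// Byte offset of a pointer inside its containing vector; this is the value
// that feeds Q6_V_valign_VVR-style realignment of unaligned loads.
inline size_t unaligned_bytes(const void * ptr) {
    return reinterpret_cast<uintptr_t>(ptr) & (kBytesPerVector - 1);
}

}  // namespace hexagon

A dot-product loop can then load full vectors from align_down(src) and shift the misaligned head/tail into place, which is the pattern the "improve src0 fetching logic" and "last vector" fixes above are refining.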

File tree

17 files changed: +1096 -199 lines changed

ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp

Lines changed: 2 additions & 7 deletions
@@ -39,7 +39,6 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
 
     const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this
     const auto kq_vec_dot   = hexagon::get_type_traits(k->get_type()).vec_dot;
-    const auto v_to_float   = hexagon::get_type_traits(v->get_type()).to_float;
     if (!q_to_vec_dot || !kq_vec_dot) {
         DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
         return;
@@ -95,7 +94,6 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
         float M = -INFINITY; // maximum KQ value
 
         float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
-        float * V32   = VKQ32 + aligned_dv; // (temporary) FP32 V buffer
         auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
         auto * Q_q   = reinterpret_cast<npu_device_fp16_t *>(
             VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
@@ -122,7 +120,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
             hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q);
         }
 
-        q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK, params->f16_to_f32_table);
+        q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
 
         // online softmax / attention
         // loop over n_kv and n_head_kv
@@ -192,10 +190,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
 
             // V += v*expf(s - M)
             DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 2, mad);
-            if (v_to_float) {
-                v_to_float(v_data, V32, DV, params->f16_to_f32_table);
-                hexagon::vec_mad_f32(V32, vs, VKQ32, DV);
-            } else {
+            {
                 // V is F32
                 hexagon::vec_mad_f32(reinterpret_cast<const float *>(v_data), vs, VKQ32, DV);
             }
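
These hunks drop the now-unused v_to_float path (V is always F32 here) and remove the f16_to_f32_table argument from the from_float conversion callback, matching the "Remove unused f16_to_f32_table parameter" bullet in the commit message. A sketch of the implied signature change follows; the typedef names are assumptions, only the dropped parameter comes from the diff.

// Sketch of the conversion-callback change implied by this hunk.
#include <cstdint>

typedef unsigned short npu_device_fp16_t;  // fp16 storage type (assumed alias)

// Before: callers had to thread a fp16 -> fp32 lookup table through.
typedef void (*from_float_with_table)(const float * src, npu_device_fp16_t * dst, int64_t count,
                                      const float * f16_to_f32_table);

// After: the table is gone, so the call site shrinks to
//   q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
typedef void (*from_float_fn)(const float * src, npu_device_fp16_t * dst, int64_t count);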

ggml/src/ggml-qnn/npu/device/op_impl.cpp

Lines changed: 20 additions & 20 deletions
@@ -6,6 +6,7 @@
 
 #include "op_flash_attn.hpp"
 #include "op_mul_mat.hpp"
+#include "op_rope.hpp"
 #include "type_traits.hpp"
 #include "vec_ops.hpp"
 
@@ -62,7 +63,7 @@ inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count
             (leftover_bytes + hexagon::unaligned_bytes(iptr1) > hexagon::kBytesPerVector) ? *iptr1 : prev1;
         curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
 
-        q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
+        hexagon::q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
     }
 }
 
@@ -179,16 +180,6 @@ template <auto _RowFunc> bool element_wise_op(hexagon::tensor * out, hexagon::co
     return true;
 }
 
-bool is_same_shape(const npu_device_tensor_spec & src, const npu_device_tensor_spec & dst) {
-    for (size_t i = 0; i < DEVICE_TENSOR_MAX_DIMS; ++i) {
-        if (src.ne[i] != dst.ne[i]) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
 bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
                                   const npu_device_tensor_spec * srcs, size_t src_len) {
     if (op != NPU_OP_ADD && op != NPU_OP_SUB && op != NPU_OP_MUL) {
@@ -228,7 +219,7 @@ bool is_element_wise_op_supported(npu_device_tensor_op op, const npu_device_tens
         return false;
     }
 
-    if (!is_same_shape(src0, *dst)) {
+    if (!hexagon::is_same_shape(src0, *dst)) {
         DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
         return false;
     }
@@ -271,7 +262,7 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
                           Q6_V_valign_VVR(Q6_Vqf32_vmpy_VsfVsf(curr, curr), Q6_V_vzero(), leftover_bytes));
     }
 
-    const float mean = hexagon::vec_reduction_qf32_f32(sum) / count; // TODO: figure out how to do division in vector
+    const float mean = hexagon::vec_reduction_f32_qf32(sum) / count; // TODO: figure out how to do division in vector
     const float scale = 1.0f / sqrtf(mean + eps); // TODO: use buildin blas sqrtf?
     hexagon::vec_scale_f32(src, scale, dst, count);
 }
@@ -354,7 +345,7 @@ bool is_unary_op_supported(npu_device_tensor_op op, const npu_device_tensor_spec
         return false;
     }
 
-    if (!is_same_shape(src0, *dst)) {
+    if (!hexagon::is_same_shape(src0, *dst)) {
         DEVICE_LOG_DEBUG("[%s]src0 and dst have different shape\n", hexagon::op_get_name(op));
         return false;
     }
@@ -396,7 +387,7 @@ constexpr const op_capabilities kOpCapabilities[] = {
         {
             element_wise_op<vec_op_f32_f32<vmul_f32_f32>>, // NPU_DATA_TYPE_F32
             element_wise_op<vec_op_f16_f16<vmul_f16_f16>>, // NPU_DATA_TYPE_F16
-        }, false,          // requires_thread_barrier
+        }, false, // requires_thread_barrier
     },
     {
         NPU_OP_RMS_NORM, is_unary_op_supported,
@@ -412,6 +403,13 @@ constexpr const op_capabilities kOpCapabilities[] = {
             nullptr, // NPU_DATA_TYPE_F16
        }, true, // requires_thread_barrier
     },
+    {
+        NPU_OP_ROPE, hexagon::is_rope_supported,
+        {
+            hexagon::rope_f32, // NPU_DATA_TYPE_F32
+            nullptr,           // NPU_DATA_TYPE_F16
+        }, false, // requires_thread_barrier
+    },
 };
 
 static_assert(kOpCapabilities[NPU_OP_MUL_MAT].compute_funcs[NPU_DATA_TYPE_F32] == hexagon::mul_mat_f32,
@@ -424,6 +422,7 @@ static_assert(kOpCapabilities[NPU_OP_RMS_NORM].op == NPU_OP_RMS_NORM,
               "kOpArray[NPU_OP_RMS_NORM].op != NPU_OP_RMS_NORM");
 static_assert(kOpCapabilities[NPU_OP_FLASH_ATTN].op == NPU_OP_FLASH_ATTN,
               "kOpArray[NPU_OP_FLASH_ATTN].op != NPU_OP_FLASH_ATTN");
+static_assert(kOpCapabilities[NPU_OP_ROPE].op == NPU_OP_ROPE, "kOpArray[NPU_OP_ROPE].op != NPU_OP_ROPE");
 
 hexagon::compute_func_type get_compute_func_impl(npu_device_tensor_op op, npu_device_tensor_data_type type) {
     if (op >= NPU_OP_COUNT) {
@@ -451,17 +450,18 @@ bool requires_thread_barrier(npu_device_tensor_op op) {
 
 bool support_op(npu_device_tensor_op op, const npu_device_tensor_spec * dst, const npu_device_tensor_spec * srcs,
                 size_t src_len) {
-    if (get_compute_func_impl(op, dst->type) == nullptr) {
-        DEVICE_LOG_ERROR("[%s]unsupported, get_compute_func failed\n", op_get_name(op));
-        return false;
-    }
-
     auto is_supported_func = kOpCapabilities[op].is_supported;
     if (!is_supported_func || !is_supported_func(op, dst, srcs, src_len)) {
         DEVICE_LOG_DEBUG("[%s]unsupported, is_supported_func return false\n", op_get_name(op));
         return false;
     }
 
+    if (get_compute_func_impl(op, dst->type) == nullptr) {
+        DEVICE_LOG_DEBUG("[%s]unsupported, get_compute_func failed, type: %s\n", op_get_name(op),
+                         get_type_name(dst->type));
+        return false;
+    }
+
     return true;
 }
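
The last hunk swaps the order of the two gates in support_op: the op-specific is_supported check now runs before the per-dtype kernel lookup, and a failed lookup is logged at DEBUG with the destination type name rather than at ERROR. A condensed, self-contained sketch of that control flow follows; the stub types mirror (but are not) the real ones, and only the check order comes from the diff.

// Condensed restatement of the reordered checks in support_op above.
#include <cstddef>

enum npu_device_tensor_op { NPU_OP_MUL_MAT, NPU_OP_ROPE, NPU_OP_COUNT };
enum npu_device_tensor_data_type { NPU_DATA_TYPE_F32, NPU_DATA_TYPE_F16, NPU_DATA_TYPE_COUNT };

struct npu_device_tensor_spec {
    npu_device_tensor_data_type type;
    // ne[], nb[], ... omitted in this sketch
};

using compute_fn      = void (*)();
using is_supported_fn = bool (*)(npu_device_tensor_op, const npu_device_tensor_spec *,
                                 const npu_device_tensor_spec *, size_t);

struct op_caps_stub {
    is_supported_fn is_supported;
    compute_fn      compute_funcs[NPU_DATA_TYPE_COUNT];
};

extern const op_caps_stub kOpCaps[NPU_OP_COUNT];  // stand-in for kOpCapabilities

bool support_op_sketch(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
                       const npu_device_tensor_spec * srcs, size_t src_len) {
    // 1) Op-specific validation (shapes, operand types) runs first; a
    //    rejection here is routine and logged at DEBUG level in the real code.
    auto is_supported = kOpCaps[op].is_supported;
    if (!is_supported || !is_supported(op, dst, srcs, src_len)) {
        return false;
    }
    // 2) Only then is the per-dtype kernel looked up, so a missing kernel is
    //    reported together with the destination type (DEBUG, no longer ERROR).
    return kOpCaps[op].compute_funcs[dst->type] != nullptr;
}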
