
Commit 379bdeb

feat: perf opt gemv (#54)
* add GEMV implementation for matrix multiplication in hexagon
* refactor: optimize GEMV implementation for matrix multiplication in hexagon
* wip
* refactor: enhance caching mechanism in GEMV implementation for matrix multiplication
* wip
* refactor: streamline caching logic in GEMV implementation for matrix multiplication
* wip
* wip
* fix broadcast in flash_attn
* format
* refactor: optimize memory fetching in matrix multiplication implementations
* wip
* fix aligned gemv
* rename
* refactor: remove unused memory cache functions and initialize VTCM cache
* wip
* feat: add vector math functions for IEEE float and half float operations
* feat: add vec_silu_f32 and vec_silu_f16 functions for SiLU activation
* feat: implement GLU operation support in tensor processing
* feat: add GLU operation support and related enhancements in tensor processing
* wip
* wip
* wip
* feat: add qhmath_hvx_div_vf functions for f32 vector operations
* feat: add qhmath_hvx_div_vhf functions for f16 vector operations
* fix: reorder parameters in vector operation functions for consistency
* wip
* feat: enhance vector operations with parameterized transformations and improved GLU implementations
* wip
* fix: increase default stack size and correct thread parameter indexing in thread pool
* fix f16 div
* fix f32 div
* fix: update GLU vector operations to use explicit denominator calculation
* wip
* wip
* Refactor cacheability check for matrix multiplication to handle multiple source tensors
* Revert "fix: increase default stack size and correct thread parameter indexing in thread pool"; this reverts commit 40e3f09.
* wip
* fix comments
* replace copy with memcpy
1 parent 6260c31 commit 379bdeb

24 files changed: +2334, -440 lines
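
For context, the GEMV path this commit optimizes computes a matrix-vector product, and the new vec_silu_f32/vec_silu_f16 helpers implement the SiLU activation. Below is a plain scalar reference of both operations, added here for illustration only: it is not the HVX/VTCM code introduced by this commit, and the function names are placeholders.

#include <cmath>
#include <cstddef>

// Scalar reference of GEMV (y = A * x) for a row-major rows x cols matrix.
// The hexagon implementation in this commit vectorizes this with HVX and
// caches operands in VTCM; none of that is shown here.
static void gemv_f32_ref(const float * a, const float * x, float * y, size_t rows, size_t cols) {
    for (size_t r = 0; r < rows; ++r) {
        float acc = 0.0f;
        const float * row = a + r * cols;
        for (size_t c = 0; c < cols; ++c) {
            acc += row[c] * x[c];  // dot product of row r with x
        }
        y[r] = acc;
    }
}

// Scalar reference of the SiLU activation behind vec_silu_f32/vec_silu_f16:
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
static void vec_silu_f32_ref(const float * src, float * dst, size_t count) {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = src[i] / (1.0f + std::exp(-src[i]));
    }
}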

ggml/src/ggml-qnn/npu/device/device.cpp

Lines changed: 27 additions & 17 deletions
@@ -1,10 +1,4 @@
 
-#include <AEEStdErr.h>
-#include <HAP_compute_res.h>
-#include <hexagon_types.h>
-
-#include <memory>
-
 #include "graph.hpp"
 #include "hexagon_npu.h"
 #include "op_impl.hpp"
@@ -14,6 +8,12 @@
 #include "type_traits.hpp"
 #include "util.hpp"
 
+#include <AEEStdErr.h>
+#include <HAP_compute_res.h>
+#include <hexagon_types.h>
+
+#include <memory>
+
 namespace {
 
 struct npu_device_context {
@@ -130,28 +130,34 @@ AEEResult npu_device_device_get_alignment(remote_handle64 _h, uint32_t * alignme
 return AEE_SUCCESS;
 }
 
-AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op op, const npu_device_tensor_spec * dst,
-const npu_device_tensor_spec * srcs, int srcsLen, boolean * is_supported) {
+AEEResult npu_device_device_support_op(remote_handle64 _h,
+const npu_device_tensor_op_spec * op_spec,
+const npu_device_tensor_spec * dst,
+const npu_device_tensor_spec * srcs,
+int srcsLen,
+boolean * is_supported) {
 NPU_UNUSED(_h);
 
 if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
 DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments");
 return AEE_EINVARGS;
 }
 
-*is_supported = hexagon::support_op(op, dst, srcs, srcsLen);
+*is_supported = hexagon::support_op(op_spec, dst, srcs, srcsLen);
 return AEE_SUCCESS;
 }
 
-AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info,
-npu_device_tensor_handle_t * tensor_handle) {
+AEEResult npu_device_tensor_init(remote_handle64 _h,
+const npu_device_tensor_config * info,
+npu_device_tensor_handle_t * tensor_handle) {
 NPU_UNUSED(_h);
 auto * tensor = new hexagon::tensor(*info);
 *tensor_handle = tensor_to_handle(tensor);
 return AEE_SUCCESS;
 }
 
-AEEResult npu_device_tensor_update_params(remote_handle64 _h, npu_device_tensor_handle_t tensor_handle,
+AEEResult npu_device_tensor_update_params(remote_handle64 _h,
+npu_device_tensor_handle_t tensor_handle,
 const npu_device_tensor_update_config * config) {
 NPU_UNUSED(_h);
 auto * tensor = tensor_from_handle(tensor_handle);
@@ -174,8 +180,9 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t
 return AEE_SUCCESS;
 }
 
-AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles,
-int tensor_handlesLen) {
+AEEResult npu_device_tensors_free(remote_handle64 _h,
+const npu_device_tensor_handle_t * tensor_handles,
+int tensor_handlesLen) {
 NPU_UNUSED(_h);
 if (!tensor_handles || tensor_handlesLen < 0) {
 DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
@@ -201,8 +208,10 @@ AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t *
 return AEE_SUCCESS;
 }
 
-AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handle_t graph_handle,
-const npu_device_tensor_handle_t * tensor_handles, int tensor_handlesLen) {
+AEEResult npu_device_graph_set_tensor(remote_handle64 _h,
+npu_device_graph_handle_t graph_handle,
+const npu_device_tensor_handle_t * tensor_handles,
+int tensor_handlesLen) {
 NPU_UNUSED(_h);
 auto * graph = graph_from_handle(graph_handle);
 if (!graph || !tensor_handles || tensor_handlesLen <= 0) {
@@ -213,7 +222,8 @@ AEEResult npu_device_graph_set_tensor(remote_handle64 _h, npu_device_graph_handl
 return AEE_SUCCESS;
 }
 
-AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h, npu_device_graph_handle_t graph_handle,
+AEEResult npu_device_graph_set_tensor_with_param(remote_handle64 _h,
+npu_device_graph_handle_t graph_handle,
 const npu_device_tensor_handle_t * tensor_handles,
 int tensor_handlesLen,
 const npu_device_tensor_update_config * tensor_params,
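
The substantive interface change in this file is that npu_device_device_support_op now receives a pointer to npu_device_tensor_op_spec instead of a bare npu_device_tensor_op enum. A hedged caller-side sketch follows; only the op field of the new struct is visible in this diff (it is read as op_spec->op in op_flash_attn.cpp), so the helper name and everything else about how the handle and tensor specs are obtained are assumptions for illustration, not code from this commit.

#include "hexagon_npu.h"  // assumed to declare the IDL-generated npu_device_* stubs and spec types

// Hypothetical host-side helper: ask the NPU whether it supports flash-attention
// for the given tensor specs via the new op_spec-based API.
static bool can_offload_flash_attn(remote_handle64 handle,
                                   const npu_device_tensor_spec * dst,
                                   const npu_device_tensor_spec * srcs,
                                   int srcs_len) {
    npu_device_tensor_op_spec op_spec = {};
    op_spec.op = NPU_OP_FLASH_ATTN;  // only the `op` member is confirmed by this diff

    boolean supported = 0;
    const AEEResult res = npu_device_device_support_op(handle, &op_spec, dst, srcs, srcs_len, &supported);
    return res == AEE_SUCCESS && supported != 0;
}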

ggml/src/ggml-qnn/npu/device/graph.cpp

Lines changed: 16 additions & 8 deletions
@@ -1,12 +1,12 @@
 
 #include "graph.hpp"
 
-#include <new>
-
 #include "op_impl.hpp"
 #include "util.hpp"
 #include "vtcm_mem.hpp"
 
+#include <new>
+
 namespace hexagon {
 
 graph::graph() noexcept {
@@ -30,8 +30,12 @@ void graph::set_tensor(const npu_device_tensor_handle_t * tensors, int tensor_co
 for (int i = 0; i < tensor_count; ++i) {
 auto * tensor_obj = reinterpret_cast<tensor *>(tensors[i]);
 _tensors[i] = tensor_obj;
-DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n", (void *) this, i, (void *) tensor_obj,
-(void *) tensor_obj->get_src(0), (void *) tensor_obj->get_src(1),
+DEVICE_LOG_DEBUG("graph(%p) set_tensor[%d]: %p(%p,%p), op: %s\n",
+(void *) this,
+i,
+(void *) tensor_obj,
+(void *) tensor_obj->get_src(0),
+(void *) tensor_obj->get_src(1),
 op_get_name(tensor_obj->get_op()));
 }
 
@@ -64,8 +68,9 @@ bool graph::compute(default_thread_pool * thread_pool, const float * f16_to_f32_
 return true;
 }
 
-void graph::thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
-void * graph) {
+void graph::thread_pool_task(default_thread_pool * pool,
+default_thread_pool::thread_params * thread_params,
+void * graph) {
 reinterpret_cast<hexagon::graph *>(graph)->compute_impl(pool, thread_params);
 }
 
@@ -86,8 +91,11 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread
 
 const bool should_sync = requires_thread_barrier(op);
 if (pool && should_sync && i < _tensor_count - 1) {
-DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]", (void *) this,
-params.get_thread_index(), i, _tensor_count);
+DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]",
+(void *) this,
+params.get_thread_index(),
+i,
+_tensor_count);
 pool->sync_thread();
 }
 }
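
The compute_impl hunk above shows the synchronization rule: after each op that requires_thread_barrier, every worker waits before moving on to the next op, and the barrier is skipped after the last op because graph::compute joins the workers anyway. The implementation of default_thread_pool::sync_thread is not part of this diff, so the sketch below illustrates the same pattern with std::barrier (C++20) as a stand-in; the loop bounds and the should_sync flag are placeholders.

#include <barrier>
#include <cstddef>
#include <thread>
#include <vector>

int main() {
    constexpr std::size_t kThreads = 4;  // stand-in for the pool's worker count
    constexpr std::size_t kOps     = 3;  // stand-in for _tensor_count

    std::barrier sync_point(kThreads);
    std::vector<std::thread> workers;

    for (std::size_t t = 0; t < kThreads; ++t) {
        workers.emplace_back([&] {
            for (std::size_t i = 0; i < kOps; ++i) {
                // ... each worker computes its slice of op i here ...
                const bool should_sync = true;  // plays the role of requires_thread_barrier(op)
                if (should_sync && i < kOps - 1) {
                    sync_point.arrive_and_wait();  // all slices of op i finish before op i+1 reads them
                }
            }
        });
    }
    for (auto & w : workers) {
        w.join();  // the final join takes the place of a barrier after the last op
    }
    return 0;
}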

ggml/src/ggml-qnn/npu/device/graph.hpp

Lines changed: 5 additions & 4 deletions
@@ -1,11 +1,11 @@
 #pragma once
 
-#include <memory>
-
 #include "hexagon_npu.h"
 #include "tensor.hpp"
 #include "thread_pool.hpp"
 
+#include <memory>
+
 namespace hexagon {
 
 class graph {
@@ -20,8 +20,9 @@ class graph {
 bool compute(default_thread_pool * thread_pool, const float * f16_to_f32_table);
 
 private:
-static void thread_pool_task(default_thread_pool * pool, default_thread_pool::thread_params * thread_params,
-void * graph);
+static void thread_pool_task(default_thread_pool * pool,
+default_thread_pool::thread_params * thread_params,
+void * graph);
 void compute_impl(default_thread_pool * pool, default_thread_pool::thread_params * thread_params);
 
 std::unique_ptr<tensor *[]> _tensors;

ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp

Lines changed: 49 additions & 23 deletions
@@ -14,15 +14,20 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
 
 // From: ggml/src/ggml-cpu/ops.cpp
 template <bool _IsKvF16>
-void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
-const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
+void flash_attn_impl(hexagon::tensor * out,
+const hexagon::tensor * q,
+const hexagon::tensor * k,
+const hexagon::tensor * v,
+const hexagon::tensor * mask,
+hexagon::compute_params * params) {
 static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
 
 constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
 
 if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
 DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
-hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
+hexagon::get_type_name(k->get_type()),
+hexagon::get_type_name(v->get_type()));
 return;
 }
 
@@ -80,7 +85,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
 const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
 uint8_t * dst_ptr = out->get_write_buffer();
 if (!dst_ptr) {
-DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
+DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n",
+(void *) out,
 hexagon::get_type_name(out->get_type()));
 return;
 }
@@ -118,7 +124,8 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
 
 const npu_device_fp16_t * mp =
 mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
-(iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
+(iq2 % mask->get_ne(2)) * mask->get_nb(2) +
+(iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
 nullptr;
 
 // k indices
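
The hunk above is the behavioral flash-attention fix in this commit ("fix broadcast in flash_attn"): the mask row pointer previously folded iq3 into the mask's dimension 2, whereas it now takes iq2 modulo ne(2) and iq3 modulo ne(3), matching the usual broadcast rule. A generic illustration of that addressing is sketched below; it is not code from this commit, and the helper name is made up.

#include <cstddef>
#include <cstdint>

// Generic broadcast addressing: when the broadcast tensor (here, the mask) has
// extent 1 along a dimension, the modulo maps every destination index back to
// that single plane, so one mask is reused across q's batch/head dimensions.
static inline const uint8_t * broadcast_row(const uint8_t * base,
                                            int64_t i1, int64_t i2, int64_t i3,
                                            const int64_t ne[4],   // extents of the broadcast tensor
                                            const size_t  nb[4]) { // strides in bytes
    return base + i1 * nb[1] + (i2 % ne[2]) * nb[2] + (i3 % ne[3]) * nb[3];
}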
@@ -251,8 +258,8 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
 const auto * v = out->get_src(2);
 const auto * mask = out->get_src(3);
 if (!q || !k || !v || !mask) {
-DEVICE_LOG_DEBUG("invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v,
-(void *) mask);
+DEVICE_LOG_DEBUG(
+"invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
 return false;
 }
 
@@ -264,8 +271,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
 return true;
 }
 
-bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
-const npu_device_tensor_spec * srcs, size_t src_len) {
+bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
+const npu_device_tensor_spec * dst,
+const npu_device_tensor_spec * srcs,
+size_t src_len) {
+const auto op = op_spec->op;
 if (op != NPU_OP_FLASH_ATTN) {
 DEVICE_LOG_DEBUG("op is not NPU_OP_FLASH_ATTN: %d\n", op);
 return false;
@@ -295,7 +305,9 @@ bool is_flash_attn_supported(npu_device_tensor_sp
 
 const auto * v = &srcs[2];
 if (v->type != k->type) { // TODO: support more v types
-DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n", op_get_name(op), get_type_name(v->type),
+DEVICE_LOG_DEBUG("[%s]v type is not the same as k: %s vs %s\n",
+op_get_name(op),
+get_type_name(v->type),
 get_type_name(k->type));
 return false;
 }
@@ -310,28 +322,42 @@ bool is_flash_attn_supported(npu_device_tensor_sp
 DEVICE_LOG_DEBUG(
 "[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, "
 "v ne: %ld, %ld, %ld, %ld\n",
-op_get_name(op), dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3],
-v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
+op_get_name(op),
+dst->ne[0],
+dst->ne[1],
+dst->ne[2],
+dst->ne[3],
+q->ne[0],
+q->ne[1],
+q->ne[2],
+q->ne[3],
+v->ne[0],
+v->ne[1],
+v->ne[2],
+v->ne[3]);
 return false;
 }
 
 if (is_transposed_or_permuted(dst->nb)) {
-DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n", op_get_name(op),
-dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n",
+op_get_name(op),
+dst->nb[0],
+dst->nb[1],
+dst->nb[2],
+dst->nb[3]);
 return false;
 }
 
 if (q->ne[0] != k->ne[0]) {
 DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
-op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
-k->ne[3]);
-return false;
-}
-
-if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
-// TODO: add broadcast support
-DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
-op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
+op_get_name(op),
+q->ne[0],
+q->ne[1],
+q->ne[2],
+q->ne[3],
+k->ne[0],
+k->ne[1],
+k->ne[2],
 k->ne[3]);
 return false;
 }

ggml/src/ggml-qnn/npu/device/op_flash_attn.hpp

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,9 @@
 namespace hexagon {
 
 bool flash_attn_f32(tensor * out, compute_params * params);
-bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_spec * dst,
-const npu_device_tensor_spec * srcs, size_t src_len);
+bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
+const npu_device_tensor_spec * dst,
+const npu_device_tensor_spec * srcs,
+size_t src_len);
 
 } // namespace hexagon
