
Commit 2cd429c

feat: perf opt part5 (#52)
* rename
* Refactor vector operations in vec_op_impl and vec_dot_product_impl for improved clarity and performance
* wip
* Enhance vector copy functions for improved performance and clarity in vec_ops.hpp
* wip
* wip
* wip
* Optimize vector dot product implementations for enhanced performance and efficiency
* Enhance flash attention implementation and type traits for improved vector operations and alignment checks

  # Conflicts:
  #   ggml/src/ggml-qnn/npu/device/type_traits.cpp

* remove align
* wip
* Enhance vector dot product implementation for improved performance by adding parallel processing for multiple vector pairs
* Revert "Enhance vector dot product implementation for improved performance by adding parallel processing for multiple vector pairs"

  This reverts commit 78cc24ed2285002ca29d6189fa61ba4ce24f8d16.

* Enhance flash attention implementation with type checks for tensor data types and improved constexpr usage
* wip
* opt mask calc
* Revert "opt mask calc"

  This reverts commit bb1840876692a11511d5ab7828b8a707402e30b9.

* wip
* opt mul mat caching logic to add dst cache
* Revert "opt mul mat caching logic to add dst cache"

  This reverts commit ab442fa9f763b3873c929936e4cb739cb1c83850.

* wip
* Refactor matrix multiplication implementation to include vector conversion and performance tracking
* wip
* wip
* wip
* create vec_ops.inl for more aggressive compiler inlining
* wip
* refactor vector dot product implementations for improved readability and performance
* refactor vector conversion functions to use HVX_Vector_Dual for improved clarity and consistency
* wip
* wip
* wip
* implement row size caching logic and enhance type traits for F32 support
* refactor matrix multiplication functions to improve caching logic and simplify tensor alignment handling
* add vector zeroing functions for F32 and F16 types to optimize memory initialization
* Revert "add vector zeroing functions for F32 and F16 types to optimize memory initialization"

  This reverts commit e374326dc74d049e6603e393ade418d9ef2b83f3.

* wip
* refactor alignment checks in dot product function to handle null pointers
* wip
* refactor load_block_generic and related functions for improved alignment handling
* wip
* refactor flash attention implementation and introduce type-erased dot function for improved type handling
* refactor dot product implementations for improved loop handling and clarity
* refactor thread_pool constructor to pre-allocate VTCM cache for each thread
* Revert "refactor thread_pool constructor to pre-allocate VTCM cache for each thread"

  This reverts commit 00cdd3f.

* wip
* opt interfaces for tensor cleanup
* refactor mul_mat_impl to use aligned size for src0 row calculation
* refactor: update dequantized_row_size logic and add size alignment checks for tensors
* wip
* wip
* refactor: replace raw pointer initialization with invalid handle constants for better clarity
* wip
1 parent fc45ad5 commit 2cd429c

17 files changed: +1000 −680 lines

ggml/src/ggml-qnn/npu/device/device.cpp

Lines changed: 36 additions & 23 deletions
@@ -4,7 +4,6 @@
 #include <hexagon_types.h>
 
 #include <memory>
-#include <new>
 
 #include "graph.hpp"
 #include "hexagon_npu.h"
@@ -69,20 +68,28 @@ struct npu_device_context {
     }
 };
 
-inline hexagon::tensor * tensor_from_handle(npu_device_graph_handle_t h) {
+inline hexagon::tensor * tensor_from_handle(npu_device_tensor_handle_t h) {
+    if (h == npu_device_INVALID_DEVICE_TENSOR_HANDLE) {
+        return nullptr;
+    }
+
     return reinterpret_cast<hexagon::tensor *>(h);
 }
 
-inline npu_device_graph_handle_t tensor_to_handle(hexagon::tensor * tensor) {
-    return reinterpret_cast<npu_device_graph_handle_t>(tensor);
+inline npu_device_tensor_handle_t tensor_to_handle(hexagon::tensor * tensor) {
+    return reinterpret_cast<npu_device_tensor_handle_t>(tensor);
 }
 
-inline hexagon::graph * graph_from_handle(npu_device_tensor_handle_t h) {
+inline hexagon::graph * graph_from_handle(npu_device_graph_handle_t h) {
+    if (h == npu_device_INVALID_DEVICE_GRAPH_HANDLE) {
+        return nullptr;
+    }
+
     return reinterpret_cast<hexagon::graph *>(h);
 }
 
-inline npu_device_tensor_handle_t graph_to_handle(hexagon::graph * graph) {
-    return reinterpret_cast<npu_device_tensor_handle_t>(graph);
+inline npu_device_graph_handle_t graph_to_handle(hexagon::graph * graph) {
+    return reinterpret_cast<npu_device_graph_handle_t>(graph);
 }
 
 inline npu_device_context * device_context_from_handle(remote_handle64 h) {
@@ -93,12 +100,7 @@ inline npu_device_context * device_context_from_handle(remote_handle64 h) {
 
 int npu_device_open(const char * uri, remote_handle64 * h) {
     // TODO: should we have a device context here?
-    auto * context = new (std::nothrow) npu_device_context();
-    if (!context) {
-        DEVICE_LOG_ERROR("Failed to allocate memory for the npu_device_context");
-        return AEE_ENOMEMORY;
-    }
-
+    auto * context = new npu_device_context();
     if (!context->init()) {
         DEVICE_LOG_ERROR("Failed to initialize npu_device_context");
         delete context;
@@ -144,12 +146,7 @@ AEEResult npu_device_device_support_op(remote_handle64 _h, npu_device_tensor_op
 AEEResult npu_device_tensor_init(remote_handle64 _h, const npu_device_tensor_config * info,
                                  npu_device_tensor_handle_t * tensor_handle) {
     NPU_UNUSED(_h);
-    auto * tensor = new (std::nothrow) hexagon::tensor(*info);
-    if (!tensor) {
-        DEVICE_LOG_ERROR("Failed to allocate memory for the tensor");
-        return AEE_ENOMEMORY;
-    }
-
+    auto * tensor = new hexagon::tensor(*info);
     *tensor_handle = tensor_to_handle(tensor);
     return AEE_SUCCESS;
 }
@@ -177,13 +174,29 @@ AEEResult npu_device_tensor_free(remote_handle64 _h, npu_device_tensor_handle_t
     return AEE_SUCCESS;
 }
 
-AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) {
+AEEResult npu_device_tensors_free(remote_handle64 _h, const npu_device_tensor_handle_t * tensor_handles,
+                                  int tensor_handlesLen) {
     NPU_UNUSED(_h);
-    auto * graph = new (std::nothrow) hexagon::graph();
-    if (!graph) {
-        return AEE_ENOMEMORY;
+    if (!tensor_handles || tensor_handlesLen < 0) {
+        DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
+        return AEE_EINVARGS;
+    }
+
+    for (int i = 0; i < tensor_handlesLen; ++i) {
+        auto * tensor = tensor_from_handle(tensor_handles[i]);
+        if (tensor) {
+            delete tensor;
+        } else {
+            DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid tensor handle at index %d", i);
+        }
     }
 
+    return AEE_SUCCESS;
+}
+
+AEEResult npu_device_graph_init(remote_handle64 _h, npu_device_graph_handle_t * graph_handle) {
+    NPU_UNUSED(_h);
+    auto * graph = new hexagon::graph();
    *graph_handle = graph_to_handle(graph);
    return AEE_SUCCESS;
 }
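
Note on the handle helpers above: the old signatures had the tensor and graph handle types crossed (tensor_from_handle took a npu_device_graph_handle_t, and graph_to_handle returned a npu_device_tensor_handle_t), and handles arriving over FastRPC were cast with no validation. The new code fixes the types and rejects the invalid-handle constants before casting. A minimal sketch of the pattern, with hypothetical stand-ins (tensor_handle_t, kInvalidTensorHandle) for the generated npu_device_* types:

    #include <cstdint>

    // Hypothetical stand-ins for the IDL-generated handle type and the new
    // invalid-handle constant; names are illustrative only.
    using tensor_handle_t = std::uintptr_t;
    constexpr tensor_handle_t kInvalidTensorHandle = 0;

    struct tensor;  // opaque device-side object

    // Reject the sentinel before reinterpreting the handle, so an
    // uninitialized handle coming over RPC yields nullptr, not a wild pointer.
    inline tensor * tensor_from_handle(tensor_handle_t h) {
        if (h == kInvalidTensorHandle) {
            return nullptr;
        }
        return reinterpret_cast<tensor *>(h);
    }

    inline tensor_handle_t tensor_to_handle(tensor * t) {
        // The commit keeps this direction as a plain cast; mapping nullptr
        // back to the sentinel here is an extra guard, not part of the diff.
        return t ? reinterpret_cast<tensor_handle_t>(t) : kInvalidTensorHandle;
    }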

ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp

Lines changed: 46 additions & 30 deletions
@@ -13,10 +13,19 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
 }
 
 // From: ggml/src/ggml-cpu/ops.cpp
+template <bool _IsKvF16>
 void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hexagon::tensor * k,
                      const hexagon::tensor * v, const hexagon::tensor * mask, hexagon::compute_params * params) {
     static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
 
+    constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
+
+    if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
+        DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
+                         hexagon::get_type_name(k->get_type()), hexagon::get_type_name(v->get_type()));
+        return;
+    }
+
     float scale = out->get_op_param<float>(0);
     const float max_bias = out->get_op_param<float>(1);
     const float logit_softcap = out->get_op_param<float>(2);
@@ -37,9 +46,11 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
     const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    const auto q_to_vec_dot = hexagon::get_type_traits(k->get_type()).from_float; // TODO: fix this
-    const auto kq_vec_dot = hexagon::get_type_traits(k->get_type()).vec_dot;
-    if (!q_to_vec_dot || !kq_vec_dot) {
+    const auto & k_type_traits = hexagon::get_type_traits(kKvDataType);
+    const auto q_to_vec_dot = k_type_traits.from_float;
+    constexpr const auto kq_vec_dot = _IsKvF16 ? hexagon::type_erase_dot_func<hexagon::vec_dot_product_f16_f16> :
+                                                 hexagon::type_erase_dot_func<hexagon::vec_dot_product_f32_f32>;
+    if (!q_to_vec_dot) {
         DEVICE_LOG_ERROR("flash_attn_impl: unsupported data type for q, k, or v\n");
         return;
     }
@@ -50,12 +61,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
     const auto DK = k->get_ne(0);
     const auto DV = v->get_ne(0);
     const auto row_bytes_q = q->get_ne(0) * hexagon::get_type_traits(q->get_type()).type_size;
-    const auto row_bytes_k = DK * hexagon::get_type_traits(k->get_type()).type_size;
+    const auto row_bytes_k = DK * k_type_traits.type_size;
     const auto row_bytes_v = DV * hexagon::get_type_traits(v->get_type()).type_size;
 
-    constexpr const size_t kFloatsPerVector = hexagon::kBytesPerVector / sizeof(float);
-    const auto aligned_dk = (DK + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
-    const auto aligned_dv = (DV + kFloatsPerVector - 1) / kFloatsPerVector * kFloatsPerVector;
+    constexpr const size_t kFloatsPerVectorPair = hexagon::kBytesPerVector * 2 / sizeof(float);
+    const auto aligned_dk = (DK + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair;
+    const auto aligned_dv = (DV + kFloatsPerVectorPair - 1) / kFloatsPerVectorPair * kFloatsPerVectorPair;
     size_t total_cache_size = sizeof(float) * (aligned_dk + 2 * aligned_dv);
     auto * cache_ptr = params->get_vtcm_cache(total_cache_size);
     if (!cache_ptr) {
@@ -64,11 +75,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
     }
 
     // loop over n_batch and n_head
-    const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
-    const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
-    const bool is_v_f16 =
-        v->get_type() == NPU_DATA_TYPE_F16; // check if V is in FP16 format, otherwise it is in FP32 format
-    uint8_t * dst_ptr = out->get_write_buffer();
+    constexpr bool is_v_f16 = _IsKvF16; // check if V is in FP16 format, otherwise it is in FP32 format
+    const auto rows_per_batch = q->get_ne(2) * q->get_ne(1);
+    const auto out_rows_per_batch = out->get_ne(2) * out->get_ne(1);
+    uint8_t * dst_ptr = out->get_write_buffer();
     if (!dst_ptr) {
         DEVICE_LOG_ERROR("flash_attn_impl: dst_ptr is not writable, tensor: %p, type: %s\n", (void *) out,
                          hexagon::get_type_name(out->get_type()));
@@ -80,6 +90,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
     const uint8_t * k_ptr = k->get_read_buffer();
     const uint8_t * v_ptr = v->get_read_buffer();
     const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
+    float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
+    auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
+    auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
+        VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
     for (auto ir = start_end_row.first; ir < start_end_row.second; ++ir) {
         // q indices
         const auto iq3 = ir / rows_per_batch;
@@ -90,15 +104,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
         const float slope =
             (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f; // sum
-        float M = -INFINITY; // maximum KQ value
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
-        auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
-        auto * Q_q = reinterpret_cast<npu_device_fp16_t *>(
-            VKQ32 + 2 * aligned_dv); // (temporary) buffer for Q converted to quantized/FP16
+        const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
+        hexagon::l2fetch_row(q_data, row_bytes_q);
 
-        if (is_v_f16) {
+        if constexpr (is_v_f16) {
             memset(VKQ16, 0, DV * sizeof(npu_device_fp16_t));
         } else {
             memset(VKQ32, 0, DV * sizeof(float));
@@ -117,16 +129,13 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const auto * q_data = q_ptr + (iq1 * q->get_nb(1) + iq2 * q->get_nb(2) + iq3 * q->get_nb(3));
-        if (iq1 < q->get_ne(1) - 1) {
-            hexagon::l2fetch_row(q_data + q->get_nb(1), row_bytes_q);
-        }
-
         q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);
 
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
+        const auto * k_plane_ptr = k_ptr + ik2 * k->get_nb(2) + ik3 * k->get_nb(3);
+        const auto * v_plane_ptr = v_ptr + iv2 * v->get_nb(2) + iv3 * v->get_nb(3);
         for (int64_t ic = 0; ic < k->get_ne(1); ++ic) {
             DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop);
             float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f;
@@ -137,7 +146,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
             float s = 0.f;
             {
                 DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 1, kq_dot);
-                const auto * k_data = k_ptr + (ic * k->get_nb(1) + ik2 * k->get_nb(2) + ik3 * k->get_nb(3));
+                const auto * k_data = k_plane_ptr + ic * k->get_nb(1);
                 if (ic < k->get_ne(1) - 1) {
                     hexagon::l2fetch_row(k_data + k->get_nb(1), row_bytes_k);
                 }
@@ -156,12 +165,12 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
             float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
             float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-            const auto * v_data = v_ptr + (ic * v->get_nb(1) + iv2 * v->get_nb(2) + iv3 * v->get_nb(3));
+            const auto * v_data = v_plane_ptr + ic * v->get_nb(1);
             if (ic < v->get_ne(1)) {
                 hexagon::l2fetch_row(v_data, row_bytes_v);
             }
 
-            if (is_v_f16) {
+            if constexpr (is_v_f16) {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
@@ -201,7 +210,7 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
             S = S * ms + vs; // scale and increment sum with partial sum
         }
 
-        if (is_v_f16) {
+        if constexpr (is_v_f16) {
             // TODO: use a more efficient conversion
             for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = f16_to_f32(VKQ16[d]);
@@ -218,7 +227,10 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
         const int i3 = iq3;
 
         // permute(0, 2, 1, 3)
-        memcpy(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1), VKQ32, out->get_nb(1));
+        hexagon::vec_cpy_f32(
+            reinterpret_cast<const float *>(VKQ32),
+            reinterpret_cast<float *>(dst_ptr + (i3 * out_rows_per_batch + i2 + i1 * out->get_ne(1)) * out->get_nb(1)),
+            out->get_ne(0));
     }
 
     out->release_write_buffer(); // mark the output tensor as modified
@@ -244,7 +256,11 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
         return false;
    }
 
-    flash_attn_impl(out, q, k, v, mask, params);
+    if (k->get_type() == NPU_DATA_TYPE_F16) {
+        flash_attn_impl<true>(out, q, k, v, mask, params);
+    } else {
+        flash_attn_impl<false>(out, q, k, v, mask, params);
+    }
     return true;
 }
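
The ic loop here is the online softmax from the referenced paper (https://arxiv.org/pdf/2112.05682.pdf): keep a running maximum M and normalizer S, and when a new score s exceeds M, rescale the existing accumulator by ms = expf(M_old - s); otherwise weight the incoming V row by vs = expf(s - M). A scalar model of one update step, reusing the diff's S/M/ms/vs names but independent of the HVX code (illustrative only, not the device implementation):

    #include <cmath>
    #include <vector>

    // Scalar model of the online-softmax accumulation in flash_attn_impl.
    // VKQ plays the role of VKQ32/VKQ16, v_row is one row of V of length DV.
    void online_softmax_step(std::vector<float> & VKQ, float & S, float & M,
                             float s, const float * v_row) {
        float ms = 1.0f; // rescale factor for the existing accumulator
        float vs = 1.0f; // weight of the incoming value row
        if (s > M) {
            ms = std::exp(M - s); // old contributions shrink under the new max
            M  = s;
        } else {
            vs = std::exp(s - M);
        }
        for (size_t d = 0; d < VKQ.size(); ++d) {
            VKQ[d] = VKQ[d] * ms + v_row[d] * vs; // V = V*ms + v*vs
        }
        S = S * ms + vs; // keep the softmax normalizer in sync
    }

On the aligned_dk/aligned_dv change: DK and DV are now rounded up to a multiple of kFloatsPerVectorPair (64 floats, assuming the usual 128-byte HVX vectors), presumably so that each of the three VTCM scratch regions (VKQ32, VKQ16, Q_q) carved out of cache_ptr starts on a vector-pair boundary.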

ggml/src/ggml-qnn/npu/device/op_impl.cpp

Lines changed: 8 additions & 61 deletions
@@ -12,64 +12,10 @@
 
 namespace {
 
-template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector), typename _TyData>
-inline void vec_op_impl(const _TyData * src0, const _TyData * src1, size_t count, _TyData * dst) {
-    constexpr const size_t kElementsPerVector = hexagon::kBytesPerVector / sizeof(_TyData);
-
-    HVX_Vector * iptr0 = ((HVX_Vector *) src0);
-    HVX_Vector * const iptr0_end = ((HVX_Vector *) src0) + (count / kElementsPerVector);
-    HVX_Vector * iptr1 = ((HVX_Vector *) src1);
-    HVX_Vector * optr = ((HVX_Vector *) dst); // framework will ensure the dst is aligned
-    HVX_Vector prev0 = *iptr0++;
-    HVX_Vector prev1 = *iptr1++;
-
-    while (iptr0 < iptr0_end) {
-        HVX_Vector curr0 = *iptr0++;
-        HVX_Vector curr1 = *iptr1++;
-        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
-        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
-        *optr++ = _OpIntrinsic(s0, s1);
-        prev0 = curr0;
-        prev1 = curr1;
-    }
-
-    const size_t leftover = count % kElementsPerVector;
-    if ((iptr0_end - ((HVX_Vector *) src0)) > 0) {
-        // handle the last vector
-        // see also:
-        //   https://github.com/UbiquitousLearning/mllm/blob/babf4410352ce8730824c87699c025a0d4ce3a6f/src/backends/qnn/LLaMAOpPackageHtp/LLaMAPackage/src/ops/LLaMAMul.cpp#L147
-        //   or qualcomm sdk libs\qhl_hvx\src\qhblas_hvx\qhblas_hvx_aw_vector_add_ah.c
-        bool should_fetch_src0 = leftover != 0 || !hexagon::is_addr_aligned(iptr0);
-        bool should_fetch_src1 = leftover != 0 || !hexagon::is_addr_aligned(iptr1);
-        HVX_Vector curr0 = should_fetch_src0 ? *iptr0 : prev0;
-        HVX_Vector curr1 = should_fetch_src1 ? *iptr1 : prev1;
-        iptr0 += should_fetch_src0 ? 1 : 0;
-        iptr1 += should_fetch_src1 ? 1 : 0;
-        HVX_Vector s0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
-        HVX_Vector s1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
-        *optr++ = _OpIntrinsic(s0, s1);
-        prev0 = curr0;
-        prev1 = curr1;
-    }
-
-    const size_t leftover_bytes = leftover * sizeof(_TyData);
-    if (leftover > 0) {
-        // handle the leftover elements
-        HVX_Vector curr0 =
-            (leftover_bytes + hexagon::unaligned_bytes(iptr0) > hexagon::kBytesPerVector) ? *iptr0 : prev0;
-        curr0 = Q6_V_valign_VVR(curr0, prev0, (size_t) src0);
-
-        HVX_Vector curr1 =
-            (leftover_bytes + hexagon::unaligned_bytes(iptr1) > hexagon::kBytesPerVector) ? *iptr1 : prev1;
-        curr1 = Q6_V_valign_VVR(curr1, prev1, (size_t) src1);
-
-        hexagon::q6op_vstu_variable_ARV(optr, leftover_bytes, _OpIntrinsic(curr0, curr1));
-    }
-}
-
-template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector)>
+template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
 inline void vec_op_f32_f32(const float * src0, const float * src1, size_t count, float * dst) {
-    vec_op_impl<_OpIntrinsic, float>(src0, src1, count, dst);
+    using namespace hexagon::vec;
+    vec_trans_op_impl<_OpBinaryTransform, float>(src0, src1, count, dst);
 }
 
 inline HVX_Vector vadd_f32_f32(HVX_Vector a, HVX_Vector b) {
@@ -84,10 +30,11 @@ inline HVX_Vector vmul_f32_f32(HVX_Vector a, HVX_Vector b) {
     return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b));
 }
 
-template <HVX_Vector (*_OpIntrinsic)(HVX_Vector, HVX_Vector)>
+template <HVX_Vector (*_OpBinaryTransform)(HVX_Vector, HVX_Vector)>
 inline void vec_op_f16_f16(const npu_device_fp16_t * src0, const npu_device_fp16_t * src1, size_t count,
                            npu_device_fp16_t * dst) {
-    vec_op_impl<_OpIntrinsic, npu_device_fp16_t>(src0, src1, count, dst);
+    using namespace hexagon::vec;
+    vec_trans_op_impl<_OpBinaryTransform, npu_device_fp16_t>(src0, src1, count, dst);
 }
 
 inline HVX_Vector vadd_f16_f16(HVX_Vector a, HVX_Vector b) {
@@ -252,10 +199,10 @@ void rms_norm_vec_f32(const float * src, size_t count, float eps, float * dst) {
         prev = curr;
     }
 
-    const size_t leftover_bytes = leftover * sizeof(float);
     if (leftover > 0) {
         // handle the leftover elements
-        HVX_Vector curr =
+        const size_t leftover_bytes = leftover * sizeof(float);
+        HVX_Vector curr =
            (leftover_bytes + hexagon::unaligned_bytes(src_vec_ptr) > hexagon::kBytesPerVector) ? *src_vec_ptr : prev;
        curr = Q6_V_valign_VVR(curr, prev, (size_t) src);
        sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum,
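
vec_trans_op_impl, which the two wrappers now call (per the commit message it lives in the new vec_ops.inl for more aggressive inlining), keeps the overall shape of the deleted vec_op_impl: a main loop that streams whole HVX vectors, using Q6_V_valign_VVR to stitch together unaligned source loads, followed by a single possibly-partial tail finished with a variable-length store. A portable scalar model of that control flow, with a hypothetical kElementsPerVector standing in for the 32 floats of a 128-byte HVX vector (illustrative only, not the device code):

    #include <cstddef>
    #include <functional>

    // Scalar model of the chunked main-loop + leftover-tail structure that
    // vec_trans_op_impl uses on HVX. On the hardware, each inner loop below
    // is a single vector instruction and the tail is one masked store.
    template <typename Op, typename T>
    void trans_op_model(const T * src0, const T * src1, size_t count, T * dst, Op op) {
        constexpr size_t kElementsPerVector = 32; // hypothetical vector width
        size_t i = 0;
        // main loop: whole "vectors" at a time
        for (; i + kElementsPerVector <= count; i += kElementsPerVector) {
            for (size_t j = 0; j < kElementsPerVector; ++j) {
                dst[i + j] = op(src0[i + j], src1[i + j]);
            }
        }
        // leftover tail: fewer than kElementsPerVector elements remain
        for (; i < count; ++i) {
            dst[i] = op(src0[i], src1[i]);
        }
    }

    // e.g. trans_op_model(a, b, n, out, std::plus<float>{});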
