Skip to content

Commit cd35b35

Browse files
committed
Add mli_prv_load_mac overloads that multiply-accumulate an array element with a scalar value
1 parent 102ef03 commit cd35b35

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

lib/src/kernels/eltwise/mli_krn_eltwise.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ static inline void __attribute__ ((always_inline)) eltwise_op_mul_fx (
357357
if ((out_size & 0x3) || (out_size < 0x7)) {
358358
for (int j = 0; j < (out_size & 0x3); j++) {
359359
auto acc = mli_prv_init_accu((io_T)0);
360-
mli_prv_load_mac(&acc, vec++, (const io_T *__restrict) &broadcast_val);
360+
mli_prv_load_mac(&acc, vec++, broadcast_val);
361361
mli_prv_clip_and_store_output(out++, &acc, mul_out_shift);
362362
}
363363
for (int j = 0; j < (out_size & ~0x3) / 2; j++) {
@@ -431,7 +431,7 @@ static inline void __attribute__ ((always_inline)) eltwise_op_mul_with_restricts
431431
if ((out_size & 0x3) || (out_size < 0x7)) {
432432
for (int j = 0; j < (out_size & 0x3); j++) {
433433
auto acc = mli_prv_init_accu((io_T)0);
434-
mli_prv_load_mac(&acc, vec++, (const io_T *__restrict) &broadcast_val);
434+
mli_prv_load_mac(&acc, vec++, broadcast_val);
435435
mli_prv_clip_and_store_output(out++, &acc, mul_out_shift);
436436
}
437437
for (int j = 0; j < (out_size & ~0x3) / 2; j++) {

lib/src/private/mli_prv_dsp.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ static inline void __attribute__ ((always_inline)) mli_prv_load_mac(
6767
MLI_PTR(in_T) __restrict in,
6868
MLI_PTR(w_T) __restrict k);
6969

70+
/* Forward declaration of the mli_prv_load_mac overload whose multiplier 'k' is
 * passed by value (w_T) instead of through a pointer; 'in' is still read
 * through a pointer and 'accu' points to the accumulator that is updated. */
template < typename in_T, typename w_T, typename acc_T >
71+
static inline void __attribute__ ((always_inline)) mli_prv_load_mac(
72+
acc_T * accu,
73+
MLI_PTR(in_T) __restrict in,
74+
w_T k);
75+
7076
template < typename in_T, typename w_T, typename acc_T >
7177
static inline void __attribute__ ((always_inline)) mli_prv_load_mac_vec2(
7278
acc_T * accu,
@@ -550,6 +556,37 @@ static inline void __attribute__ ((always_inline)) mli_prv_load_mac(
550556
*accu = _dmachbl(*in, *(MLI_PTR(uint8_t)) k);
551557
}
552558

559+
/* Multiply-accumulate of one int16_t input element with a scalar int16_t value.
 *
 * accu - in/out accumulator; its current value is passed to the MAC and the
 *        result is written back to it.
 * in   - pointer to the input element; only *in is read.
 * k    - scalar multiplier, passed by value.
 *
 * NOTE(review): fx_a40_mac_q15 is presumably a Q15 fractional MAC into a
 * 40-bit accumulator (judging by its name) - confirm against the intrinsic docs.
 */
static inline void __attribute__ ((always_inline)) mli_prv_load_mac(
560+
accum40_t * accu,
561+
const MLI_PTR(int16_t) __restrict in,
562+
const int16_t k) {
563+
*accu = fx_a40_mac_q15(*accu, *in, k);
564+
}
565+
566+
/* Multiply-accumulate of one int8_t input element with a scalar int8_t value.
 *
 * accu - receives the value returned by _dmachbl; the previous *accu is not
 *        passed to the intrinsic explicitly.
 *        NOTE(review): presumably the accumulation happens in a hardware
 *        accumulator that must already hold the running sum - confirm.
 * in   - pointer to the input element; one byte is read through it.
 * k    - scalar multiplier, passed by value as the first intrinsic operand.
 */
static inline void __attribute__ ((always_inline)) mli_prv_load_mac(
567+
int32_t * accu,
568+
const MLI_PTR(int8_t) __restrict in,
569+
const int8_t k) {
570+
/* casting the 'in' pointer to unsigned to make sure no sign extension happens on the load
571+
* this way the 'second' byte contains zeros. and it is safe to use dmac.
572+
* the sign extension happens inside the dmachbl operation.
573+
* 'k' is a by-value scalar here (no load); presumably it is sign extended to 16bit when passed - TODO confirm.
574+
* the value of the second half is don't care because it will be multiplied by 0
575+
*/
576+
*accu = _dmachbl(k, *(MLI_PTR(uint8_t)) in);
577+
}
578+
579+
/* Multiply-accumulate of one int16_t input element with a scalar int8_t value.
 *
 * accu - receives the value returned by _dmachbl; the previous *accu is not
 *        passed to the intrinsic explicitly.
 *        NOTE(review): presumably the accumulation happens in a hardware
 *        accumulator that must already hold the running sum - confirm.
 * in   - pointer to the input element; *in is read as int16_t.
 * k    - scalar multiplier, passed by value and cast to uint8_t.
 */
static inline void __attribute__ ((always_inline)) mli_prv_load_mac(
580+
int32_t * accu, const MLI_PTR(int16_t) __restrict in,
581+
const int8_t k) {
582+
/* casting the scalar 'k' to unsigned to make sure no sign extension happens on the conversion
583+
* this way the 'second' byte contains zeros. and it is safe to use dmac.
584+
* the sign extension happens inside the dmachbl operation.
585+
* 'in' is dereferenced as a full 16bit signed value, so no cast is applied to it.
586+
* the value of the second half is don't care because it will be multiplied by 0
587+
*/
588+
*accu = _dmachbl(*in, (uint8_t)k);
589+
}
553590

554591
static inline void __attribute__ ((always_inline)) mli_prv_load_mac_vec2(
555592
accum40_t * accu,

0 commit comments

Comments
 (0)