@@ -87,7 +87,7 @@ static inline int32_t dot7u_inner(const int8_t* a, const int8_t* b, const int32_
8787 return vaddvq_s32 (vaddq_s32 (acc5, acc6));
8888}
8989
90- EXPORT int32_t vec_dot7u (int8_t * a, int8_t * b, const int32_t dims) {
90+ EXPORT int32_t vec_dot7u (const int8_t * a, const int8_t * b, const int32_t dims) {
9191 int32_t res = 0 ;
9292 int i = 0 ;
9393 if (dims > DOT7U_STRIDE_BYTES_LEN) {
@@ -100,6 +100,103 @@ EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
100100 return res;
101101}
102102
103+ template <int32_t (*mapper)(const int32_t , const int32_t *)>
104+ static inline void dot7u_inner_bulk (const int8_t * a, const int8_t * b, const int32_t dims, const int32_t * offsets, const int32_t count, f32_t * results) {
105+ size_t blk = dims & ~15 ;
106+ size_t c = 0 ;
107+
108+ // Process 4 vectors at a time
109+ for (; c + 3 < count; c += 4 ) {
110+ const int8_t * a0 = a + mapper (c, offsets) * dims;
111+ const int8_t * a1 = a + mapper (c + 1 , offsets) * dims;
112+ const int8_t * a2 = a + mapper (c + 2 , offsets) * dims;
113+ const int8_t * a3 = a + mapper (c + 3 , offsets) * dims;
114+
115+ int32x4_t acc0 = vdupq_n_s32 (0 );
116+ int32x4_t acc1 = vdupq_n_s32 (0 );
117+ int32x4_t acc2 = vdupq_n_s32 (0 );
118+ int32x4_t acc3 = vdupq_n_s32 (0 );
119+ int32x4_t acc4 = vdupq_n_s32 (0 );
120+ int32x4_t acc5 = vdupq_n_s32 (0 );
121+ int32x4_t acc6 = vdupq_n_s32 (0 );
122+ int32x4_t acc7 = vdupq_n_s32 (0 );
123+
124+ for (size_t i = 0 ; i < blk; i += 16 ) {
125+ int8x16_t vb = vld1q_s8 (b + i);
126+
127+ int8x16_t v0 = vld1q_s8 (a0 + i);
128+ int16x8_t lo0 = vmull_s8 (vget_low_s8 (v0), vget_low_s8 (vb));
129+ int16x8_t hi0 = vmull_s8 (vget_high_s8 (v0), vget_high_s8 (vb));
130+ acc0 = vpadalq_s16 (acc0, lo0);
131+ acc1 = vpadalq_s16 (acc1, hi0);
132+
133+ int8x16_t v1 = vld1q_s8 (a1 + i);
134+ int16x8_t lo1 = vmull_s8 (vget_low_s8 (v1), vget_low_s8 (vb));
135+ int16x8_t hi1 = vmull_s8 (vget_high_s8 (v1), vget_high_s8 (vb));
136+ acc2 = vpadalq_s16 (acc2, lo1);
137+ acc3 = vpadalq_s16 (acc3, hi1);
138+
139+ int8x16_t v2 = vld1q_s8 (a2 + i);
140+ int16x8_t lo2 = vmull_s8 (vget_low_s8 (v2), vget_low_s8 (vb));
141+ int16x8_t hi2 = vmull_s8 (vget_high_s8 (v2), vget_high_s8 (vb));
142+ acc4 = vpadalq_s16 (acc4, lo2);
143+ acc5 = vpadalq_s16 (acc5, hi2);
144+
145+ int8x16_t v3 = vld1q_s8 (a3 + i);
146+ int16x8_t lo3 = vmull_s8 (vget_low_s8 (v3), vget_low_s8 (vb));
147+ int16x8_t hi3 = vmull_s8 (vget_high_s8 (v3), vget_high_s8 (vb));
148+ acc6 = vpadalq_s16 (acc6, lo3);
149+ acc7 = vpadalq_s16 (acc7, hi3);
150+ }
151+ int32x4_t acc01 = vaddq_s32 (acc0, acc1);
152+ int32x4_t acc23 = vaddq_s32 (acc2, acc3);
153+ int32x4_t acc45 = vaddq_s32 (acc4, acc5);
154+ int32x4_t acc67 = vaddq_s32 (acc6, acc7);
155+
156+ int32_t acc_scalar0 = vaddvq_s32 (acc01);
157+ int32_t acc_scalar1 = vaddvq_s32 (acc23);
158+ int32_t acc_scalar2 = vaddvq_s32 (acc45);
159+ int32_t acc_scalar3 = vaddvq_s32 (acc67);
160+ if (blk != dims) {
161+ // scalar tail
162+ for (size_t t = blk; t < dims; t++) {
163+ const int8_t bb = b[t];
164+ acc_scalar0 += a0[t] * bb;
165+ acc_scalar1 += a1[t] * bb;
166+ acc_scalar2 += a2[t] * bb;
167+ acc_scalar3 += a3[t] * bb;
168+ }
169+ }
170+ results[c + 0 ] = (f32_t )acc_scalar0;
171+ results[c + 1 ] = (f32_t )acc_scalar1;
172+ results[c + 2 ] = (f32_t )acc_scalar2;
173+ results[c + 3 ] = (f32_t )acc_scalar3;
174+ }
175+
176+ // Tail-handling: remaining 0..3 vectors
177+ for (; c < count; c++) {
178+ const int8_t * a0 = a + mapper (c, offsets) * dims;
179+ results[c] = (f32_t )vec_dot7u (a0, b, dims);
180+ }
181+ }
182+
183+ static inline int identity (const int32_t i, const int32_t * offsets) {
184+ return i;
185+ }
186+
187+ static inline int index (const int32_t i, const int32_t * offsets) {
188+ return offsets[i];
189+ }
190+
// Bulk dot product for a CONTIGUOUS batch: vector i starts at a + i * dims.
// Computes, for each of the `count` vectors in `a`, its dot product against
// the single vector `b`, writing one f32 per vector into `results`.
// Dispatches to the shared kernel with the identity mapper; no offsets table
// is needed, so NULL is passed through.
EXPORT void vec_dot7u_bulk (const int8_t * a, const int8_t * b, const int32_t dims, const int32_t count, f32_t * results) {
    dot7u_inner_bulk<identity>(a, b, dims, NULL , count, results);
}
194+
195+
// Bulk dot product for a GATHERED batch: vector i starts at
// a + offsets[i] * dims.  Computes, for each of the `count` selected vectors,
// its dot product against the single vector `b`, writing one f32 per vector
// into `results`.  Dispatches to the shared kernel with the offsets-lookup
// mapper; `offsets` must hold at least `count` entries.
EXPORT void vec_dot7u_bulk_offsets (const int8_t * a, const int8_t * b, const int32_t dims, const int32_t * offsets, const int32_t count, f32_t * results) {
    dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
}
199+
103200static inline int32_t sqr7u_inner (int8_t *a, int8_t *b, const int32_t dims) {
104201 int32x4_t acc1 = vdupq_n_s32 (0 );
105202 int32x4_t acc2 = vdupq_n_s32 (0 );
0 commit comments