#include <stddef.h>
#include <stdint.h>
#include <math.h>
- #include "vec.h"

- // AVX-512 code
+ // Force the preprocessor to pick up AVX-512 intrinsics, and the compiler to emit AVX-512 code
#ifdef __clang__
#pragma clang attribute push(__attribute__((target("arch=skylake-avx512"))), apply_to=function)
#elif __GNUC__
#pragma GCC push_options
#pragma GCC target ("arch=skylake-avx512")
#endif

+ #include "vec.h"
+ #include "vec_common.h"
+ #include "amd64/amd64_vec_common.h"
+
// Includes for intrinsics
#ifdef _MSC_VER
#include <intrin.h>
@@ -133,42 +136,70 @@ static inline void dot7u_inner_bulk(
    const int32_t count,
    f32_t* results
) {
-     if (dims > STRIDE_BYTES_LEN) {
-         const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
-         for (int32_t c = 0; c < count; c++) {
-             const int8_t* a0 = a + (mapper(c, offsets) * pitch);
-             int i = limit;
-             int32_t res = dot7u_inner_avx512(a0, b, i);
-             for (; i < dims; i++) {
-                 res += a0[i] * b[i];
-             }
-             results[c] = (f32_t)res;
+     const int blk = dims & ~(STRIDE_BYTES_LEN - 1);
+     const int lines_to_fetch = dims / CACHE_LINE_SIZE + 1;
+     int c = 0;
+
+     const int8_t* a0 = safe_mapper_offset<0, mapper>(a, pitch, offsets, count);
+     const int8_t* a1 = safe_mapper_offset<1, mapper>(a, pitch, offsets, count);
+     const int8_t* a2 = safe_mapper_offset<2, mapper>(a, pitch, offsets, count);
+     const int8_t* a3 = safe_mapper_offset<3, mapper>(a, pitch, offsets, count);
+
+     // Process a batch of 4 vectors at a time, after instructing the CPU to
+     // prefetch the next batch.
+     // Prefetching multiple memory locations while computing keeps the CPU
+     // execution units busy.
+     for (; c + 7 < count; c += 4) {
+         const int8_t* next_a0 = a + mapper(c + 4, offsets) * pitch;
+         const int8_t* next_a1 = a + mapper(c + 5, offsets) * pitch;
+         const int8_t* next_a2 = a + mapper(c + 6, offsets) * pitch;
+         const int8_t* next_a3 = a + mapper(c + 7, offsets) * pitch;
+
+         prefetch(next_a0, lines_to_fetch);
+         prefetch(next_a1, lines_to_fetch);
+         prefetch(next_a2, lines_to_fetch);
+         prefetch(next_a3, lines_to_fetch);
+
+         int32_t res0 = 0;
+         int32_t res1 = 0;
+         int32_t res2 = 0;
+         int32_t res3 = 0;
+         int i = 0;
+         if (dims > STRIDE_BYTES_LEN) {
+             i = blk;
+             res0 = dot7u_inner_avx512(a0, b, i);
+             res1 = dot7u_inner_avx512(a1, b, i);
+             res2 = dot7u_inner_avx512(a2, b, i);
+             res3 = dot7u_inner_avx512(a3, b, i);
        }
-     } else {
-         for (int32_t c = 0; c < count; c++) {
-             const int8_t* a0 = a + (mapper(c, offsets) * pitch);
-             int32_t res = 0;
-             for (int32_t i = 0; i < dims; i++) {
-                 res += a0[i] * b[i];
-             }
-             results[c] = (f32_t)res;
+         for (; i < dims; i++) {
+             const int8_t bb = b[i];
+             res0 += a0[i] * bb;
+             res1 += a1[i] * bb;
+             res2 += a2[i] * bb;
+             res3 += a3[i] * bb;
        }
+         results[c + 0] = (f32_t)res0;
+         results[c + 1] = (f32_t)res1;
+         results[c + 2] = (f32_t)res2;
+         results[c + 3] = (f32_t)res3;
+         a0 = next_a0;
+         a1 = next_a1;
+         a2 = next_a2;
+         a3 = next_a3;
    }
- }

- static inline int64_t identity(const int32_t i, const int32_t* offsets) {
-     return i;
- }
-
- static inline int64_t index(const int32_t i, const int32_t* offsets) {
-     return offsets[i];
+     // Tail-handling: remaining vectors
+     for (; c < count; c++) {
+         const int8_t* a0 = a + mapper(c, offsets) * pitch;
+         results[c] = (f32_t)vec_dot7u_2(a0, b, dims);
+     }
}

EXPORT void vec_dot7u_bulk_2(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
-     dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, results);
+     dot7u_inner_bulk<identity_mapper>(a, b, dims, dims, NULL, count, results);
}

-
EXPORT void vec_dot7u_bulk_offsets_2(
    const int8_t* a,
    const int8_t* b,
@@ -177,7 +208,7 @@ EXPORT void vec_dot7u_bulk_offsets_2(
    const int32_t* offsets,
    const int32_t count,
    f32_t* results) {
-     dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, results);
+     dot7u_inner_bulk<array_mapper>(a, b, dims, pitch, offsets, count, results);
}

template <int offsetRegs>
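
Note on context: the new code relies on helpers that live in vec_common.h and amd64/amd64_vec_common.h and are not shown in this diff (CACHE_LINE_SIZE, prefetch, safe_mapper_offset, identity_mapper, array_mapper). The following is a minimal sketch of what such helpers could look like, assuming a 64-byte cache line, _mm_prefetch-based software prefetching, and index clamping in safe_mapper_offset; it is an illustration, not the patch's actual headers.

// Illustrative sketch only -- the real definitions live in vec_common.h /
// amd64/amd64_vec_common.h and may differ.
#include <stdint.h>
#include <immintrin.h>

#define CACHE_LINE_SIZE 64  // assumption: 64-byte cache lines

// Mapper signature used as the template parameter of dot7u_inner_bulk.
typedef int64_t (*mapper_t)(const int32_t i, const int32_t* offsets);

// Dense layout: row c lives at a + c * pitch.
static inline int64_t identity_mapper(const int32_t i, const int32_t* offsets) {
    (void)offsets;
    return i;
}

// Indirect layout: row c lives at a + offsets[c] * pitch.
static inline int64_t array_mapper(const int32_t i, const int32_t* offsets) {
    return offsets[i];
}

// Pull `lines` consecutive cache lines starting at `ptr` into the cache
// ahead of the loads that will need them.
static inline void prefetch(const void* ptr, const int lines) {
    for (int l = 0; l < lines; l++) {
        _mm_prefetch((const char*)ptr + l * CACHE_LINE_SIZE, _MM_HINT_T0);
    }
}

// Pointer to the N-th mapped row, with the index clamped (assumed behaviour)
// so that setting up the first batch of four pointers is safe when count < 4.
template <int N, mapper_t mapper>
static inline const int8_t* safe_mapper_offset(
        const int8_t* a, const int32_t pitch,
        const int32_t* offsets, const int32_t count) {
    const int32_t idx = (N < count) ? N : (count > 0 ? count - 1 : 0);
    return a + mapper(idx, offsets) * pitch;
}

With helpers of this shape, each iteration of the bulk kernel scores four rows while the next four are being prefetched, hiding memory latency behind the AVX-512 dot-product work; the last few rows (at most seven, since the main loop requires c + 7 < count so its prefetch indices stay in range) fall through to the per-row vec_dot7u_2 tail loop.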