Skip to content

Commit de51eea

Browse files
committed
Fix/reconcile/merge cpp files after merge
1 parent b8cbd5a commit de51eea

File tree

7 files changed

+156
-207
lines changed

7 files changed

+156
-207
lines changed

libs/simdvec/native/publish_vec_binaries.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
2020
exit 1;
2121
fi
2222

23-
VERSION="1.0.17"
23+
VERSION="1.0.18"
2424
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
2525
TEMP=$(mktemp -d)
2626

libs/simdvec/native/src/vec/c/aarch64/vec_1.cpp

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ static inline int32_t dot7u_inner(const int8_t* a, const int8_t* b, const int32_
8787
return vaddvq_s32(vaddq_s32(acc5, acc6));
8888
}
8989

90-
EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
90+
EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
9191
int32_t res = 0;
9292
int i = 0;
9393
if (dims > DOT7U_STRIDE_BYTES_LEN) {
@@ -100,6 +100,103 @@ EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
100100
return res;
101101
}
102102

103+
template <int32_t(*mapper)(const int32_t, const int32_t*)>
104+
static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
105+
size_t blk = dims & ~15;
106+
size_t c = 0;
107+
108+
// Process 4 vectors at a time
109+
for (; c + 3 < count; c += 4) {
110+
const int8_t* a0 = a + mapper(c, offsets) * dims;
111+
const int8_t* a1 = a + mapper(c + 1, offsets) * dims;
112+
const int8_t* a2 = a + mapper(c + 2, offsets) * dims;
113+
const int8_t* a3 = a + mapper(c + 3, offsets) * dims;
114+
115+
int32x4_t acc0 = vdupq_n_s32(0);
116+
int32x4_t acc1 = vdupq_n_s32(0);
117+
int32x4_t acc2 = vdupq_n_s32(0);
118+
int32x4_t acc3 = vdupq_n_s32(0);
119+
int32x4_t acc4 = vdupq_n_s32(0);
120+
int32x4_t acc5 = vdupq_n_s32(0);
121+
int32x4_t acc6 = vdupq_n_s32(0);
122+
int32x4_t acc7 = vdupq_n_s32(0);
123+
124+
for (size_t i = 0; i < blk; i += 16) {
125+
int8x16_t vb = vld1q_s8(b + i);
126+
127+
int8x16_t v0 = vld1q_s8(a0 + i);
128+
int16x8_t lo0 = vmull_s8(vget_low_s8(v0), vget_low_s8(vb));
129+
int16x8_t hi0 = vmull_s8(vget_high_s8(v0), vget_high_s8(vb));
130+
acc0 = vpadalq_s16(acc0, lo0);
131+
acc1 = vpadalq_s16(acc1, hi0);
132+
133+
int8x16_t v1 = vld1q_s8(a1 + i);
134+
int16x8_t lo1 = vmull_s8(vget_low_s8(v1), vget_low_s8(vb));
135+
int16x8_t hi1 = vmull_s8(vget_high_s8(v1), vget_high_s8(vb));
136+
acc2 = vpadalq_s16(acc2, lo1);
137+
acc3 = vpadalq_s16(acc3, hi1);
138+
139+
int8x16_t v2 = vld1q_s8(a2 + i);
140+
int16x8_t lo2 = vmull_s8(vget_low_s8(v2), vget_low_s8(vb));
141+
int16x8_t hi2 = vmull_s8(vget_high_s8(v2), vget_high_s8(vb));
142+
acc4 = vpadalq_s16(acc4, lo2);
143+
acc5 = vpadalq_s16(acc5, hi2);
144+
145+
int8x16_t v3 = vld1q_s8(a3 + i);
146+
int16x8_t lo3 = vmull_s8(vget_low_s8(v3), vget_low_s8(vb));
147+
int16x8_t hi3 = vmull_s8(vget_high_s8(v3), vget_high_s8(vb));
148+
acc6 = vpadalq_s16(acc6, lo3);
149+
acc7 = vpadalq_s16(acc7, hi3);
150+
}
151+
int32x4_t acc01 = vaddq_s32(acc0, acc1);
152+
int32x4_t acc23 = vaddq_s32(acc2, acc3);
153+
int32x4_t acc45 = vaddq_s32(acc4, acc5);
154+
int32x4_t acc67 = vaddq_s32(acc6, acc7);
155+
156+
int32_t acc_scalar0 = vaddvq_s32(acc01);
157+
int32_t acc_scalar1 = vaddvq_s32(acc23);
158+
int32_t acc_scalar2 = vaddvq_s32(acc45);
159+
int32_t acc_scalar3 = vaddvq_s32(acc67);
160+
if (blk != dims) {
161+
// scalar tail
162+
for (size_t t = blk; t < dims; t++) {
163+
const int8_t bb = b[t];
164+
acc_scalar0 += a0[t] * bb;
165+
acc_scalar1 += a1[t] * bb;
166+
acc_scalar2 += a2[t] * bb;
167+
acc_scalar3 += a3[t] * bb;
168+
}
169+
}
170+
results[c + 0] = (f32_t)acc_scalar0;
171+
results[c + 1] = (f32_t)acc_scalar1;
172+
results[c + 2] = (f32_t)acc_scalar2;
173+
results[c + 3] = (f32_t)acc_scalar3;
174+
}
175+
176+
// Tail-handling: remaining 0..3 vectors
177+
for (; c < count; c++) {
178+
const int8_t* a0 = a + mapper(c, offsets) * dims;
179+
results[c] = (f32_t)vec_dot7u(a0, b, dims);
180+
}
181+
}
182+
183+
// Mapper for contiguous layouts: the k-th vector is stored at position k.
// Returns int32_t to match the mapper type declared by dot7u_inner_bulk's
// template parameter; `offsets` is deliberately ignored (callers pass NULL).
static inline int32_t identity(const int32_t i, const int32_t* offsets) {
    (void)offsets;
    return i;
}
186+
187+
// Mapper for indirect layouts: the position of the k-th vector is read from
// the offsets table. Returns int32_t to match the mapper type declared by
// dot7u_inner_bulk's template parameter. `offsets` must be non-NULL and hold
// at least i + 1 entries; no bounds check is performed here.
static inline int32_t index(const int32_t i, const int32_t* offsets) {
    return offsets[i];
}
190+
191+
/*
 * Exported bulk entry point for contiguous vectors: scores `count` vectors of
 * `dims` bytes each, stored back-to-back starting at `a`, against the query
 * `b`, and writes one f32_t per vector into `results`.
 */
EXPORT void vec_dot7u_bulk(
    const int8_t* a,
    const int8_t* b,
    const int32_t dims,
    const int32_t count,
    f32_t* results
) {
    // The identity mapper never reads the offsets table, so none is supplied.
    dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
}
194+
195+
196+
/*
 * Exported bulk entry point for indirectly addressed vectors: the k-th vector
 * scored is the one at position offsets[k] (in units of `dims` bytes) from
 * `a`. Writes one f32_t per entry into `results`.
 */
EXPORT void vec_dot7u_bulk_offsets(
    const int8_t* a,
    const int8_t* b,
    const int32_t dims,
    const int32_t* offsets,
    const int32_t count,
    f32_t* results
) {
    // The index mapper looks each vector position up in `offsets`.
    dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
}
199+
103200
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
104201
int32x4_t acc1 = vdupq_n_s32(0);
105202
int32x4_t acc2 = vdupq_n_s32(0);

libs/simdvec/native/src/vec/c/aarch64/vec_bulk.cpp

Lines changed: 0 additions & 123 deletions
This file was deleted.

libs/simdvec/native/src/vec/c/amd64/vec_1.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ static inline int32_t dot7u_inner(const int8_t* a, const int8_t* b, const int32_
144144
return hsum_i32_8(acc1);
145145
}
146146

147-
EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
147+
EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
148148
int32_t res = 0;
149149
int i = 0;
150150
if (dims > STRIDE_BYTES_LEN) {
@@ -157,6 +157,49 @@ EXPORT int32_t vec_dot7u(int8_t* a, int8_t* b, const int32_t dims) {
157157
return res;
158158
}
159159

160+
template <int32_t(*mapper)(int32_t, const int32_t*)>
161+
static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
162+
int32_t res = 0;
163+
if (dims > STRIDE_BYTES_LEN) {
164+
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
165+
for (int32_t c = 0; c < count; c++) {
166+
const int8_t* a0 = a + mapper(c, offsets) * dims;
167+
int i = limit;
168+
res = dot7u_inner(a, b, i);
169+
for (; i < dims; i++) {
170+
res += a0[i] * b[i];
171+
}
172+
results[c] = (f32_t)res;
173+
}
174+
} else {
175+
for (int32_t c = 0; c < count; c++) {
176+
const int8_t* a0 = a + mapper(c, offsets) * dims;
177+
res = 0;
178+
for (int32_t i = 0; i < dims; i++) {
179+
res += a0[i] * b[i];
180+
}
181+
results[c] = (f32_t)res;
182+
}
183+
}
184+
}
185+
186+
// Mapper for contiguous layouts: the k-th vector is stored at position k.
// Returns int32_t to match the mapper type declared by dot7u_inner_bulk's
// template parameter; `offsets` is deliberately ignored (callers pass NULL).
static inline int32_t identity(const int32_t i, const int32_t* offsets) {
    (void)offsets;
    return i;
}
189+
190+
// Mapper for indirect layouts: the position of the k-th vector is read from
// the offsets table. Returns int32_t to match the mapper type declared by
// dot7u_inner_bulk's template parameter. `offsets` must be non-NULL and hold
// at least i + 1 entries; no bounds check is performed here.
static inline int32_t index(const int32_t i, const int32_t* offsets) {
    return offsets[i];
}
193+
194+
/*
 * Exported bulk entry point for contiguous vectors: scores `count` vectors of
 * `dims` bytes each, stored back-to-back starting at `a`, against the query
 * `b`, and writes one f32_t per vector into `results`.
 */
EXPORT void vec_dot7u_bulk(
    const int8_t* a,
    const int8_t* b,
    const int32_t dims,
    const int32_t count,
    f32_t* results
) {
    // The identity mapper never reads the offsets table, so none is supplied.
    dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
}
197+
198+
199+
/*
 * Exported bulk entry point for indirectly addressed vectors: the k-th vector
 * scored is the one at position offsets[k] (in units of `dims` bytes) from
 * `a`. Writes one f32_t per entry into `results`.
 */
EXPORT void vec_dot7u_bulk_offsets(
    const int8_t* a,
    const int8_t* b,
    const int32_t dims,
    const int32_t* offsets,
    const int32_t count,
    f32_t* results
) {
    // The index mapper looks each vector position up in `offsets`.
    dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
}
202+
160203
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {
161204
// Init accumulator(s) with 0
162205
__m256i acc1 = _mm256_setzero_si256();

libs/simdvec/native/src/vec/c/amd64/vec_2.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) {
5757
return _mm512_add_epi32(_mm512_madd_epi16(ones, dot), acc);
5858
}
5959

60-
static inline int32_t dot7u_inner_avx512(int8_t* a, const int8_t* b, const int32_t dims) {
60+
static inline int32_t dot7u_inner_avx512(const int8_t* a, const int8_t* b, const int32_t dims) {
6161
constexpr int stride8 = 8 * STRIDE_BYTES_LEN;
6262
constexpr int stride4 = 4 * STRIDE_BYTES_LEN;
6363
const int8_t* p1 = a;
@@ -110,7 +110,7 @@ static inline int32_t dot7u_inner_avx512(int8_t* a, const int8_t* b, const int32
110110
return _mm512_reduce_add_epi32(_mm512_add_epi32(acc0, acc4));
111111
}
112112

113-
EXPORT int32_t vec_dot7u_2(int8_t* a, int8_t* b, const int32_t dims) {
113+
EXPORT int32_t vec_dot7u_2(const int8_t* a, const int8_t* b, const int32_t dims) {
114114
int32_t res = 0;
115115
int i = 0;
116116
if (dims > STRIDE_BYTES_LEN) {
@@ -123,8 +123,8 @@ EXPORT int32_t vec_dot7u_2(int8_t* a, int8_t* b, const int32_t dims) {
123123
return res;
124124
}
125125

126-
template <int(*mapper)(int, int32_t*)>
127-
static inline void dot7u_inner_bulk(int8_t* a, int8_t* b, int dims, int32_t* offsets, int count, f32_t* results) {
126+
template <int32_t(*mapper)(int32_t, const int32_t*)>
127+
static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
128128
int32_t res = 0;
129129
if (dims > STRIDE_BYTES_LEN) {
130130
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
@@ -149,21 +149,20 @@ static inline void dot7u_inner_bulk(int8_t* a, int8_t* b, int dims, int32_t* off
149149
}
150150
}
151151

152-
static inline int identity(int i, int32_t* offsets) {
152+
// Mapper for contiguous layouts: the k-th vector is stored at position k.
// Returns int32_t to match the mapper type declared by dot7u_inner_bulk's
// template parameter; `offsets` is deliberately ignored (callers pass NULL).
static inline int32_t identity(const int32_t i, const int32_t* offsets) {
    (void)offsets;
    return i;
}
155155

156-
static inline int index(int i, int32_t* offsets) {
156+
// Mapper for indirect layouts: the position of the k-th vector is read from
// the offsets table. Returns int32_t to match the mapper type declared by
// dot7u_inner_bulk's template parameter. `offsets` must be non-NULL and hold
// at least i + 1 entries; no bounds check is performed here.
static inline int32_t index(const int32_t i, const int32_t* offsets) {
    return offsets[i];
}
159159

160-
extern "C"
161-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, int dims, int count, f32_t* results) {
160+
/*
 * Exported bulk entry point (the "_2" variant of this translation unit):
 * forwards to dot7u_inner_bulk with the identity mapper and writes one f32_t
 * per scored vector into `results`.
 * NOTE(review): presumably the vectors are stored back-to-back in `a`, as the
 * identity mapper implies - the inner routine's body is not fully visible
 * here, confirm.
 */
EXPORT void vec_dot7u_bulk_2(
    const int8_t* a,
    const int8_t* b,
    const int32_t dims,
    const int32_t count,
    f32_t* results
) {
    // The identity mapper never reads the offsets table, so none is supplied.
    dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
}
164163

165-
extern "C"
166-
EXPORT void dot7u_bulk_offsets(int8_t* a, int8_t* b, int dims, int32_t* offsets, int count, f32_t* results) {
164+
165+
/*
 * Exported bulk entry point (the "_2" variant of this translation unit) for
 * indirectly addressed vectors: forwards to dot7u_inner_bulk with the index
 * mapper, which reads each vector position from `offsets`, and writes one
 * f32_t per entry into `results`.
 */
EXPORT void vec_dot7u_bulk_offsets_2(
    const int8_t* a,
    const int8_t* b,
    const int32_t dims,
    const int32_t* offsets,
    const int32_t count,
    f32_t* results
) {
    // The index mapper looks each vector position up in `offsets`.
    dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
}
169168

0 commit comments

Comments
 (0)