Skip to content

Commit 5f572f0

Browse files
authored
fix: choose an appropriate SIMD implementation for aarch64 (dragonflydb#579)
1 parent 69d9ef2 commit 5f572f0

File tree

1 file changed

+78
-58
lines changed

1 file changed

+78
-58
lines changed

src/core/detail/bitpacking.cc

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#endif
1515
#include <absl/base/internal/endian.h>
1616

17+
using namespace std;
18+
1719
namespace dfly {
1820

1921
namespace detail {
@@ -31,6 +33,75 @@ static inline uint64_t Compress8x7bit(uint64_t x) {
3133
return x;
3234
}
3335

36+
static inline pair<const char*, uint8_t*> simd_variant1_pack(const char* ascii, const char* end,
37+
uint8_t* bin) {
38+
__m128i val, rpart, lpart;
39+
40+
// Skips 8th byte (indexc 7) in the lower 8-byte part.
41+
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
42+
43+
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
44+
while (ascii <= end) {
45+
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
46+
47+
/*
48+
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
49+
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
50+
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
51+
*/
52+
53+
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x007F007F007F007F));
54+
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x7F007F007F007F00));
55+
val = _mm_or_si128(_mm_srli_epi64(lpart, 1), rpart);
56+
57+
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x00003FFF00003FFF));
58+
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x3FFF00003FFF0000));
59+
val = _mm_or_si128(_mm_srli_epi64(lpart, 2), rpart);
60+
61+
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
62+
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
63+
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
64+
65+
val = _mm_shuffle_epi8(val, control);
66+
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
67+
bin += 14;
68+
ascii += 16;
69+
}
70+
71+
return make_pair(ascii, bin);
72+
}
73+
74+
static inline pair<const char*, uint8_t*> simd_variant2_pack(const char* ascii, const char* end,
75+
uint8_t* bin) {
76+
// Skips 8th byte (indexc 7) in the lower 8-byte part.
77+
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
78+
79+
__m128i val, rpart, lpart;
80+
81+
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
82+
while (ascii <= end) {
83+
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
84+
85+
/*
86+
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
87+
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
88+
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
89+
*/
90+
val = _mm_maddubs_epi16(_mm_set1_epi16(0x8001), val);
91+
val = _mm_madd_epi16(_mm_set1_epi32(0x40000001), val);
92+
93+
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
94+
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
95+
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
96+
97+
val = _mm_shuffle_epi8(val, control);
98+
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
99+
bin += 14;
100+
ascii += 16;
101+
}
102+
return make_pair(ascii, bin);
103+
}
104+
34105
// Daniel Lemire's function validate_ascii_fast() - under Apache/MIT license.
35106
// See https://github.com/lemire/fastvalidate-utf-8/
36107
// The function returns true (1) if all chars passed in src are
@@ -103,38 +174,7 @@ void ascii_pack_simd(const char* ascii, size_t len, uint8_t* bin) {
103174
// overwrite we finish loop one iteration earlier.
104175
const char* end = ascii + len - 32;
105176

106-
// Skips 8th byte (indexc 7) in the lower 8-byte part.
107-
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
108-
109-
__m128i val, rpart, lpart;
110-
111-
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
112-
while (ascii <= end) {
113-
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
114-
115-
/*
116-
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
117-
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
118-
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
119-
*/
120-
121-
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x007F007F007F007F));
122-
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x7F007F007F007F00));
123-
val = _mm_or_si128(_mm_srli_epi64(lpart, 1), rpart);
124-
125-
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x00003FFF00003FFF));
126-
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x3FFF00003FFF0000));
127-
val = _mm_or_si128(_mm_srli_epi64(lpart, 2), rpart);
128-
129-
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
130-
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
131-
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
132-
133-
val = _mm_shuffle_epi8(val, control);
134-
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
135-
bin += 14;
136-
ascii += 16;
137-
}
177+
tie(ascii, bin) = simd_variant1_pack(ascii, end, bin);
138178

139179
end += 32; // Bring back end.
140180
DCHECK(ascii < end);
@@ -147,32 +187,12 @@ void ascii_pack_simd2(const char* ascii, size_t len, uint8_t* bin) {
147187
// overwrite we finish loop one iteration earlier.
148188
const char* end = ascii + len - 32;
149189

150-
// Skips 8th byte (indexc 7) in the lower 8-byte part.
151-
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
152-
153-
__m128i val, rpart, lpart;
154-
155-
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
156-
while (ascii <= end) {
157-
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
158-
159-
/*
160-
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
161-
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
162-
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
163-
*/
164-
val = _mm_maddubs_epi16(_mm_set1_epi16(0x8001), val);
165-
val = _mm_madd_epi16(_mm_set1_epi32(0x40000001), val);
166-
167-
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
168-
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
169-
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
170-
171-
val = _mm_shuffle_epi8(val, control);
172-
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
173-
bin += 14;
174-
ascii += 16;
175-
}
190+
// on arm var
191+
#if defined(__aarch64__)
192+
tie(ascii, bin) = simd_variant1_pack(ascii, end, bin);
193+
#else
194+
tie(ascii, bin) = simd_variant2_pack(ascii, end, bin);
195+
#endif
176196

177197
end += 32; // Bring back end.
178198
DCHECK(ascii < end);

0 commit comments

Comments
 (0)