Skip to content

Commit d48ac77

Browse files
committed
Added AVX-512 optimizations
1 parent cd2470f commit d48ac77

File tree

2 files changed

+76
-159
lines changed

2 files changed

+76
-159
lines changed

src/simd/x86/avx2_ops.c

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -407,42 +407,22 @@ void carquet_avx2_gather_i64(const int64_t* dict, const uint32_t* indices,
407407

408408
/**
409409
* Gather float values from dictionary using AVX2 gather instructions.
410+
* Note: float and int32 are both 4 bytes, so we reuse gather_i32 via cast.
410411
*/
411412
void carquet_avx2_gather_float(const float* dict, const uint32_t* indices,
412413
int64_t count, float* output) {
413-
int64_t i = 0;
414-
415-
/* Process 8 at a time using AVX2 gather */
416-
for (; i + 8 <= count; i += 8) {
417-
__m256i idx = _mm256_loadu_si256((const __m256i*)(indices + i));
418-
__m256 result = _mm256_i32gather_ps(dict, idx, 4); /* Scale = 4 bytes per float */
419-
_mm256_storeu_ps(output + i, result);
420-
}
421-
422-
/* Handle remaining */
423-
for (; i < count; i++) {
424-
output[i] = dict[indices[i]];
425-
}
414+
/* Data movement doesn't care about type - reuse int32 implementation */
415+
carquet_avx2_gather_i32((const int32_t*)dict, indices, count, (int32_t*)output);
426416
}
427417

428418
/**
429419
* Gather double values from dictionary using AVX2 gather instructions.
420+
* Note: double and int64 are both 8 bytes, so we reuse gather_i64 via cast.
430421
*/
431422
void carquet_avx2_gather_double(const double* dict, const uint32_t* indices,
432423
int64_t count, double* output) {
433-
int64_t i = 0;
434-
435-
/* Process 4 at a time using AVX2 gather */
436-
for (; i + 4 <= count; i += 4) {
437-
__m128i idx = _mm_loadu_si128((const __m128i*)(indices + i));
438-
__m256d result = _mm256_i32gather_pd(dict, idx, 8); /* Scale = 8 bytes per double */
439-
_mm256_storeu_pd(output + i, result);
440-
}
441-
442-
/* Handle remaining */
443-
for (; i < count; i++) {
444-
output[i] = dict[indices[i]];
445-
}
424+
/* Data movement doesn't care about type - reuse int64 implementation */
425+
carquet_avx2_gather_i64((const int64_t*)dict, indices, count, (int64_t*)output);
446426
}
447427

448428
/* ============================================================================

src/simd/x86/avx512_ops.c

Lines changed: 70 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -116,95 +116,57 @@ void carquet_avx512_byte_stream_split_encode_float(
116116
int64_t i = 0;
117117

118118
#ifdef __AVX512VBMI__
119-
/* Permutation indices to gather bytes by position across 16 floats */
120-
/* For 16 floats (64 bytes), gather all byte 0s, then byte 1s, etc. */
121-
const __m512i perm_b0 = _mm512_set_epi8(
122-
60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0,
123-
60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0,
124-
60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0,
125-
60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0);
126-
const __m512i perm_b1 = _mm512_set_epi8(
127-
61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1,
128-
61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1,
129-
61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1,
130-
61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1);
131-
const __m512i perm_b2 = _mm512_set_epi8(
132-
62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2,
133-
62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2,
134-
62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2,
135-
62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2);
136-
const __m512i perm_b3 = _mm512_set_epi8(
137-
63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3,
138-
63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3,
139-
63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3,
140-
63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3);
119+
/* Single permutation that places all 4 byte streams in the 4 128-bit lanes:
120+
* Lane 0 (bits 0-127): byte 0 from each of 16 floats
121+
* Lane 1 (bits 128-255): byte 1 from each of 16 floats
122+
* Lane 2 (bits 256-383): byte 2 from each of 16 floats
123+
* Lane 3 (bits 384-511): byte 3 from each of 16 floats
124+
*/
125+
const __m512i perm_all = _mm512_set_epi8(
126+
63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, /* byte 3s */
127+
62, 58, 54, 50, 46, 42, 38, 34, 30, 26, 22, 18, 14, 10, 6, 2, /* byte 2s */
128+
61, 57, 53, 49, 45, 41, 37, 33, 29, 25, 21, 17, 13, 9, 5, 1, /* byte 1s */
129+
60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0); /* byte 0s */
141130

142131
for (; i + 16 <= count; i += 16) {
143132
__m512i v = _mm512_loadu_si512((const __m512i*)(src + i * 4));
144133

145-
/* Permute to gather bytes by position */
146-
__m512i b0 = _mm512_permutexvar_epi8(perm_b0, v);
147-
__m512i b1 = _mm512_permutexvar_epi8(perm_b1, v);
148-
__m512i b2 = _mm512_permutexvar_epi8(perm_b2, v);
149-
__m512i b3 = _mm512_permutexvar_epi8(perm_b3, v);
150-
151-
/* Store 16 bytes to each stream (only lower 128 bits valid) */
152-
_mm_storeu_si128((__m128i*)(output + 0 * count + i), _mm512_castsi512_si128(b0));
153-
_mm_storeu_si128((__m128i*)(output + 1 * count + i), _mm512_castsi512_si128(b1));
154-
_mm_storeu_si128((__m128i*)(output + 2 * count + i), _mm512_castsi512_si128(b2));
155-
_mm_storeu_si128((__m128i*)(output + 3 * count + i), _mm512_castsi512_si128(b3));
134+
/* Single permutation gathers all 4 streams */
135+
__m512i transposed = _mm512_permutexvar_epi8(perm_all, v);
136+
137+
/* Extract and store each 128-bit lane to its stream */
138+
_mm_storeu_si128((__m128i*)(output + 0 * count + i), _mm512_castsi512_si128(transposed));
139+
_mm_storeu_si128((__m128i*)(output + 1 * count + i), _mm512_extracti32x4_epi32(transposed, 1));
140+
_mm_storeu_si128((__m128i*)(output + 2 * count + i), _mm512_extracti32x4_epi32(transposed, 2));
141+
_mm_storeu_si128((__m128i*)(output + 3 * count + i), _mm512_extracti32x4_epi32(transposed, 3));
156142
}
157143
#else
158-
/* Fallback without VBMI: use shuffle approach */
144+
/* Fallback without VBMI: use shuffle + permutexvar approach
145+
* Step 1: shuffle_epi8 transposes within each 128-bit lane (4 floats -> 4 bytes per stream)
146+
* Step 2: permutexvar_epi32 rearranges dwords to group all byte 0s, byte 1s, etc.
147+
*/
148+
const __m512i intra_lane_shuf = _mm512_set_epi8(
149+
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
150+
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
151+
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
152+
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
153+
const __m512i cross_lane_perm = _mm512_set_epi32(
154+
15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
155+
159156
for (; i + 16 <= count; i += 16) {
160-
/* Load as 4 128-bit chunks */
161-
__m128i v0 = _mm_loadu_si128((const __m128i*)(src + i * 4 + 0));
162-
__m128i v1 = _mm_loadu_si128((const __m128i*)(src + i * 4 + 16));
163-
__m128i v2 = _mm_loadu_si128((const __m128i*)(src + i * 4 + 32));
164-
__m128i v3 = _mm_loadu_si128((const __m128i*)(src + i * 4 + 48));
165-
166-
/* Transpose 4x4 blocks of bytes using unpack operations */
167-
__m128i t0 = _mm_unpacklo_epi8(v0, v1); /* a0b0a1b1... */
168-
__m128i t1 = _mm_unpackhi_epi8(v0, v1);
169-
__m128i t2 = _mm_unpacklo_epi8(v2, v3);
170-
__m128i t3 = _mm_unpackhi_epi8(v2, v3);
171-
172-
__m128i u0 = _mm_unpacklo_epi8(t0, t2);
173-
__m128i u1 = _mm_unpackhi_epi8(t0, t2);
174-
__m128i u2 = _mm_unpacklo_epi8(t1, t3);
175-
__m128i u3 = _mm_unpackhi_epi8(t1, t3);
176-
177-
/* Extract and store byte streams using shuffle */
178-
const __m128i shuf_b0 = _mm_set_epi8(-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, 12,8,4,0);
179-
const __m128i shuf_b1 = _mm_set_epi8(-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, 13,9,5,1);
180-
const __m128i shuf_b2 = _mm_set_epi8(-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, 14,10,6,2);
181-
const __m128i shuf_b3 = _mm_set_epi8(-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, 15,11,7,3);
182-
183-
/* Extract 4 bytes from each of 4 chunks = 16 bytes per stream */
184-
uint32_t* out0 = (uint32_t*)(output + 0 * count + i);
185-
uint32_t* out1 = (uint32_t*)(output + 1 * count + i);
186-
uint32_t* out2 = (uint32_t*)(output + 2 * count + i);
187-
uint32_t* out3 = (uint32_t*)(output + 3 * count + i);
188-
189-
out0[0] = _mm_extract_epi32(_mm_shuffle_epi8(u0, shuf_b0), 0);
190-
out0[1] = _mm_extract_epi32(_mm_shuffle_epi8(u1, shuf_b0), 0);
191-
out0[2] = _mm_extract_epi32(_mm_shuffle_epi8(u2, shuf_b0), 0);
192-
out0[3] = _mm_extract_epi32(_mm_shuffle_epi8(u3, shuf_b0), 0);
193-
194-
out1[0] = _mm_extract_epi32(_mm_shuffle_epi8(u0, shuf_b1), 0);
195-
out1[1] = _mm_extract_epi32(_mm_shuffle_epi8(u1, shuf_b1), 0);
196-
out1[2] = _mm_extract_epi32(_mm_shuffle_epi8(u2, shuf_b1), 0);
197-
out1[3] = _mm_extract_epi32(_mm_shuffle_epi8(u3, shuf_b1), 0);
198-
199-
out2[0] = _mm_extract_epi32(_mm_shuffle_epi8(u0, shuf_b2), 0);
200-
out2[1] = _mm_extract_epi32(_mm_shuffle_epi8(u1, shuf_b2), 0);
201-
out2[2] = _mm_extract_epi32(_mm_shuffle_epi8(u2, shuf_b2), 0);
202-
out2[3] = _mm_extract_epi32(_mm_shuffle_epi8(u3, shuf_b2), 0);
203-
204-
out3[0] = _mm_extract_epi32(_mm_shuffle_epi8(u0, shuf_b3), 0);
205-
out3[1] = _mm_extract_epi32(_mm_shuffle_epi8(u1, shuf_b3), 0);
206-
out3[2] = _mm_extract_epi32(_mm_shuffle_epi8(u2, shuf_b3), 0);
207-
out3[3] = _mm_extract_epi32(_mm_shuffle_epi8(u3, shuf_b3), 0);
157+
__m512i v = _mm512_loadu_si512((const __m512i*)(src + i * 4));
158+
159+
/* Transpose within each 128-bit lane */
160+
__m512i shuffled = _mm512_shuffle_epi8(v, intra_lane_shuf);
161+
162+
/* Rearrange dwords across lanes to group streams */
163+
__m512i transposed = _mm512_permutexvar_epi32(cross_lane_perm, shuffled);
164+
165+
/* Extract and store each 128-bit lane to its stream */
166+
_mm_storeu_si128((__m128i*)(output + 0 * count + i), _mm512_castsi512_si128(transposed));
167+
_mm_storeu_si128((__m128i*)(output + 1 * count + i), _mm512_extracti32x4_epi32(transposed, 1));
168+
_mm_storeu_si128((__m128i*)(output + 2 * count + i), _mm512_extracti32x4_epi32(transposed, 2));
169+
_mm_storeu_si128((__m128i*)(output + 3 * count + i), _mm512_extracti32x4_epi32(transposed, 3));
208170
}
209171
#endif
210172

@@ -400,49 +362,22 @@ void carquet_avx512_gather_i64(const int64_t* dict, const uint32_t* indices,
400362

401363
/**
402364
* Gather float values from dictionary using AVX-512 gather instructions.
365+
* Note: float and int32 are both 4 bytes, so we reuse gather_i32 via cast.
403366
*/
404367
void carquet_avx512_gather_float(const float* dict, const uint32_t* indices,
405368
int64_t count, float* output) {
406-
int64_t i = 0;
407-
408-
/* Process 16 at a time using AVX-512 gather */
409-
for (; i + 16 <= count; i += 16) {
410-
__m512i idx = _mm512_loadu_si512((const __m512i*)(indices + i));
411-
__m512 result = _mm512_i32gather_ps(idx, dict, 4);
412-
_mm512_storeu_ps(output + i, result);
413-
}
414-
415-
/* Handle remaining with AVX2 */
416-
for (; i + 8 <= count; i += 8) {
417-
__m256i idx = _mm256_loadu_si256((const __m256i*)(indices + i));
418-
__m256 result = _mm256_i32gather_ps(dict, idx, 4);
419-
_mm256_storeu_ps(output + i, result);
420-
}
421-
422-
/* Handle remaining */
423-
for (; i < count; i++) {
424-
output[i] = dict[indices[i]];
425-
}
369+
/* Data movement doesn't care about type - reuse int32 implementation */
370+
carquet_avx512_gather_i32((const int32_t*)dict, indices, count, (int32_t*)output);
426371
}
427372

428373
/**
429374
* Gather double values from dictionary using AVX-512 gather instructions.
375+
* Note: double and int64 are both 8 bytes, so we reuse gather_i64 via cast.
430376
*/
431377
void carquet_avx512_gather_double(const double* dict, const uint32_t* indices,
432378
int64_t count, double* output) {
433-
int64_t i = 0;
434-
435-
/* Process 8 at a time using AVX-512 gather */
436-
for (; i + 8 <= count; i += 8) {
437-
__m256i idx = _mm256_loadu_si256((const __m256i*)(indices + i));
438-
__m512d result = _mm512_i32gather_pd(idx, dict, 8);
439-
_mm512_storeu_pd(output + i, result);
440-
}
441-
442-
/* Handle remaining */
443-
for (; i < count; i++) {
444-
output[i] = dict[indices[i]];
445-
}
379+
/* Data movement doesn't care about type - reuse int64 implementation */
380+
carquet_avx512_gather_i64((const int64_t*)dict, indices, count, (int64_t*)output);
446381
}
447382

448383
/* ============================================================================
@@ -558,13 +493,8 @@ void carquet_avx512_unpack_bools(const uint8_t* input, uint8_t* output, int64_t
558493
uint64_t packed;
559494
memcpy(&packed, input + byte_idx, 8);
560495

561-
/* Convert to mask */
562-
__mmask64 mask = (__mmask64)packed;
563-
564-
/* Create result: 1 where mask bit is set, 0 otherwise */
565-
__m512i ones = _mm512_set1_epi8(1);
566-
__m512i zeros = _mm512_setzero_si512();
567-
__m512i result = _mm512_mask_mov_epi8(zeros, mask, ones);
496+
/* Convert to mask and create result with maskz_set1 (1 where set, 0 otherwise) */
497+
__m512i result = _mm512_maskz_set1_epi8((__mmask64)packed, 1);
568498

569499
_mm512_storeu_si512((__m512i*)(output + i), result);
570500
}
@@ -587,23 +517,30 @@ void carquet_avx512_pack_bools(const uint8_t* input, uint8_t* output, int64_t co
587517
for (; i + 64 <= count; i += 64) {
588518
__m512i bools = _mm512_loadu_si512((const __m512i*)(input + i));
589519

590-
/* Compare with zero to get mask */
591-
__mmask64 mask = _mm512_cmpneq_epi8_mask(bools, _mm512_setzero_si512());
520+
/* Use test_epi8_mask: bit is set if (a & b) != 0, i.e., if bool is non-zero */
521+
__mmask64 mask = _mm512_test_epi8_mask(bools, bools);
592522

593523
/* Store mask as 8 bytes */
594524
uint64_t packed = (uint64_t)mask;
595525
memcpy(output + i / 8, &packed, 8);
596526
}
597527

598-
/* Handle remaining */
599-
for (; i < count; i += 8) {
600-
uint8_t byte = 0;
601-
for (int64_t j = 0; j < 8 && i + j < count; j++) {
602-
if (input[i + j]) {
603-
byte |= (1 << j);
604-
}
605-
}
606-
output[i / 8] = byte;
528+
/* Handle remaining elements with masked load */
529+
if (i < count) {
530+
int64_t remaining = count - i;
531+
/* Create mask for remaining elements: set bits 0..(remaining-1) */
532+
__mmask64 load_mask = (remaining >= 64) ? ~0ULL : ((1ULL << remaining) - 1);
533+
534+
/* Masked load zeros out elements beyond the mask */
535+
__m512i bools = _mm512_maskz_loadu_epi8(load_mask, input + i);
536+
537+
/* Test for non-zero values */
538+
__mmask64 result_mask = _mm512_test_epi8_mask(bools, bools);
539+
540+
/* Write only the bytes we need */
541+
int64_t bytes_to_write = (remaining + 7) / 8;
542+
uint64_t packed = (uint64_t)result_mask;
543+
memcpy(output + i / 8, &packed, (size_t)bytes_to_write);
607544
}
608545
}
609546

0 commit comments

Comments
 (0)