@@ -116,95 +116,57 @@ void carquet_avx512_byte_stream_split_encode_float(
116116 int64_t i = 0 ;
117117
118118#ifdef __AVX512VBMI__
119- /* Permutation indices to gather bytes by position across 16 floats */
120- /* For 16 floats (64 bytes), gather all byte 0s, then byte 1s, etc. */
121- const __m512i perm_b0 = _mm512_set_epi8 (
122- 60 , 56 , 52 , 48 , 44 , 40 , 36 , 32 , 28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 ,
123- 60 , 56 , 52 , 48 , 44 , 40 , 36 , 32 , 28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 ,
124- 60 , 56 , 52 , 48 , 44 , 40 , 36 , 32 , 28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 ,
125- 60 , 56 , 52 , 48 , 44 , 40 , 36 , 32 , 28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 );
126- const __m512i perm_b1 = _mm512_set_epi8 (
127- 61 , 57 , 53 , 49 , 45 , 41 , 37 , 33 , 29 , 25 , 21 , 17 , 13 , 9 , 5 , 1 ,
128- 61 , 57 , 53 , 49 , 45 , 41 , 37 , 33 , 29 , 25 , 21 , 17 , 13 , 9 , 5 , 1 ,
129- 61 , 57 , 53 , 49 , 45 , 41 , 37 , 33 , 29 , 25 , 21 , 17 , 13 , 9 , 5 , 1 ,
130- 61 , 57 , 53 , 49 , 45 , 41 , 37 , 33 , 29 , 25 , 21 , 17 , 13 , 9 , 5 , 1 );
131- const __m512i perm_b2 = _mm512_set_epi8 (
132- 62 , 58 , 54 , 50 , 46 , 42 , 38 , 34 , 30 , 26 , 22 , 18 , 14 , 10 , 6 , 2 ,
133- 62 , 58 , 54 , 50 , 46 , 42 , 38 , 34 , 30 , 26 , 22 , 18 , 14 , 10 , 6 , 2 ,
134- 62 , 58 , 54 , 50 , 46 , 42 , 38 , 34 , 30 , 26 , 22 , 18 , 14 , 10 , 6 , 2 ,
135- 62 , 58 , 54 , 50 , 46 , 42 , 38 , 34 , 30 , 26 , 22 , 18 , 14 , 10 , 6 , 2 );
136- const __m512i perm_b3 = _mm512_set_epi8 (
137- 63 , 59 , 55 , 51 , 47 , 43 , 39 , 35 , 31 , 27 , 23 , 19 , 15 , 11 , 7 , 3 ,
138- 63 , 59 , 55 , 51 , 47 , 43 , 39 , 35 , 31 , 27 , 23 , 19 , 15 , 11 , 7 , 3 ,
139- 63 , 59 , 55 , 51 , 47 , 43 , 39 , 35 , 31 , 27 , 23 , 19 , 15 , 11 , 7 , 3 ,
140- 63 , 59 , 55 , 51 , 47 , 43 , 39 , 35 , 31 , 27 , 23 , 19 , 15 , 11 , 7 , 3 );
119+ /* Single permutation that places all 4 byte streams in the 4 128-bit lanes:
120+ * Lane 0 (bits 0-127): byte 0 from each of 16 floats
121+ * Lane 1 (bits 128-255): byte 1 from each of 16 floats
122+ * Lane 2 (bits 256-383): byte 2 from each of 16 floats
123+ * Lane 3 (bits 384-511): byte 3 from each of 16 floats
124+ */
125+ const __m512i perm_all = _mm512_set_epi8 (
126+ 63 , 59 , 55 , 51 , 47 , 43 , 39 , 35 , 31 , 27 , 23 , 19 , 15 , 11 , 7 , 3 , /* byte 3s */
127+ 62 , 58 , 54 , 50 , 46 , 42 , 38 , 34 , 30 , 26 , 22 , 18 , 14 , 10 , 6 , 2 , /* byte 2s */
128+ 61 , 57 , 53 , 49 , 45 , 41 , 37 , 33 , 29 , 25 , 21 , 17 , 13 , 9 , 5 , 1 , /* byte 1s */
129+ 60 , 56 , 52 , 48 , 44 , 40 , 36 , 32 , 28 , 24 , 20 , 16 , 12 , 8 , 4 , 0 ); /* byte 0s */
141130
142131 for (; i + 16 <= count ; i += 16 ) {
143132 __m512i v = _mm512_loadu_si512 ((const __m512i * )(src + i * 4 ));
144133
145- /* Permute to gather bytes by position */
146- __m512i b0 = _mm512_permutexvar_epi8 (perm_b0 , v );
147- __m512i b1 = _mm512_permutexvar_epi8 (perm_b1 , v );
148- __m512i b2 = _mm512_permutexvar_epi8 (perm_b2 , v );
149- __m512i b3 = _mm512_permutexvar_epi8 (perm_b3 , v );
150-
151- /* Store 16 bytes to each stream (only lower 128 bits valid) */
152- _mm_storeu_si128 ((__m128i * )(output + 0 * count + i ), _mm512_castsi512_si128 (b0 ));
153- _mm_storeu_si128 ((__m128i * )(output + 1 * count + i ), _mm512_castsi512_si128 (b1 ));
154- _mm_storeu_si128 ((__m128i * )(output + 2 * count + i ), _mm512_castsi512_si128 (b2 ));
155- _mm_storeu_si128 ((__m128i * )(output + 3 * count + i ), _mm512_castsi512_si128 (b3 ));
134+ /* Single permutation gathers all 4 streams */
135+ __m512i transposed = _mm512_permutexvar_epi8 (perm_all , v );
136+
137+ /* Extract and store each 128-bit lane to its stream */
138+ _mm_storeu_si128 ((__m128i * )(output + 0 * count + i ), _mm512_castsi512_si128 (transposed ));
139+ _mm_storeu_si128 ((__m128i * )(output + 1 * count + i ), _mm512_extracti32x4_epi32 (transposed , 1 ));
140+ _mm_storeu_si128 ((__m128i * )(output + 2 * count + i ), _mm512_extracti32x4_epi32 (transposed , 2 ));
141+ _mm_storeu_si128 ((__m128i * )(output + 3 * count + i ), _mm512_extracti32x4_epi32 (transposed , 3 ));
156142 }
157143#else
158- /* Fallback without VBMI: use shuffle approach */
144+ /* Fallback without VBMI: use shuffle + permutexvar approach
145+ * Step 1: shuffle_epi8 transposes within each 128-bit lane (4 floats -> 4 bytes per stream)
146+ * Step 2: permutexvar_epi32 rearranges dwords to group all byte 0s, byte 1s, etc.
147+ */
148+ const __m512i intra_lane_shuf = _mm512_set_epi8 (
149+ 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 ,
150+ 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 ,
151+ 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 ,
152+ 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
153+ const __m512i cross_lane_perm = _mm512_set_epi32 (
154+ 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
155+
159156 for (; i + 16 <= count ; i += 16 ) {
160- /* Load as 4 128-bit chunks */
161- __m128i v0 = _mm_loadu_si128 ((const __m128i * )(src + i * 4 + 0 ));
162- __m128i v1 = _mm_loadu_si128 ((const __m128i * )(src + i * 4 + 16 ));
163- __m128i v2 = _mm_loadu_si128 ((const __m128i * )(src + i * 4 + 32 ));
164- __m128i v3 = _mm_loadu_si128 ((const __m128i * )(src + i * 4 + 48 ));
165-
166- /* Transpose 4x4 blocks of bytes using unpack operations */
167- __m128i t0 = _mm_unpacklo_epi8 (v0 , v1 ); /* a0b0a1b1... */
168- __m128i t1 = _mm_unpackhi_epi8 (v0 , v1 );
169- __m128i t2 = _mm_unpacklo_epi8 (v2 , v3 );
170- __m128i t3 = _mm_unpackhi_epi8 (v2 , v3 );
171-
172- __m128i u0 = _mm_unpacklo_epi8 (t0 , t2 );
173- __m128i u1 = _mm_unpackhi_epi8 (t0 , t2 );
174- __m128i u2 = _mm_unpacklo_epi8 (t1 , t3 );
175- __m128i u3 = _mm_unpackhi_epi8 (t1 , t3 );
176-
177- /* Extract and store byte streams using shuffle */
178- const __m128i shuf_b0 = _mm_set_epi8 (-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 , 12 ,8 ,4 ,0 );
179- const __m128i shuf_b1 = _mm_set_epi8 (-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 , 13 ,9 ,5 ,1 );
180- const __m128i shuf_b2 = _mm_set_epi8 (-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 , 14 ,10 ,6 ,2 );
181- const __m128i shuf_b3 = _mm_set_epi8 (-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 ,-1 , 15 ,11 ,7 ,3 );
182-
183- /* Extract 4 bytes from each of 4 chunks = 16 bytes per stream */
184- uint32_t * out0 = (uint32_t * )(output + 0 * count + i );
185- uint32_t * out1 = (uint32_t * )(output + 1 * count + i );
186- uint32_t * out2 = (uint32_t * )(output + 2 * count + i );
187- uint32_t * out3 = (uint32_t * )(output + 3 * count + i );
188-
189- out0 [0 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u0 , shuf_b0 ), 0 );
190- out0 [1 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u1 , shuf_b0 ), 0 );
191- out0 [2 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u2 , shuf_b0 ), 0 );
192- out0 [3 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u3 , shuf_b0 ), 0 );
193-
194- out1 [0 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u0 , shuf_b1 ), 0 );
195- out1 [1 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u1 , shuf_b1 ), 0 );
196- out1 [2 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u2 , shuf_b1 ), 0 );
197- out1 [3 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u3 , shuf_b1 ), 0 );
198-
199- out2 [0 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u0 , shuf_b2 ), 0 );
200- out2 [1 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u1 , shuf_b2 ), 0 );
201- out2 [2 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u2 , shuf_b2 ), 0 );
202- out2 [3 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u3 , shuf_b2 ), 0 );
203-
204- out3 [0 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u0 , shuf_b3 ), 0 );
205- out3 [1 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u1 , shuf_b3 ), 0 );
206- out3 [2 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u2 , shuf_b3 ), 0 );
207- out3 [3 ] = _mm_extract_epi32 (_mm_shuffle_epi8 (u3 , shuf_b3 ), 0 );
157+ __m512i v = _mm512_loadu_si512 ((const __m512i * )(src + i * 4 ));
158+
159+ /* Transpose within each 128-bit lane */
160+ __m512i shuffled = _mm512_shuffle_epi8 (v , intra_lane_shuf );
161+
162+ /* Rearrange dwords across lanes to group streams */
163+ __m512i transposed = _mm512_permutexvar_epi32 (cross_lane_perm , shuffled );
164+
165+ /* Extract and store each 128-bit lane to its stream */
166+ _mm_storeu_si128 ((__m128i * )(output + 0 * count + i ), _mm512_castsi512_si128 (transposed ));
167+ _mm_storeu_si128 ((__m128i * )(output + 1 * count + i ), _mm512_extracti32x4_epi32 (transposed , 1 ));
168+ _mm_storeu_si128 ((__m128i * )(output + 2 * count + i ), _mm512_extracti32x4_epi32 (transposed , 2 ));
169+ _mm_storeu_si128 ((__m128i * )(output + 3 * count + i ), _mm512_extracti32x4_epi32 (transposed , 3 ));
208170 }
209171#endif
210172
@@ -400,49 +362,22 @@ void carquet_avx512_gather_i64(const int64_t* dict, const uint32_t* indices,
400362
/**
 * Gather float values from dictionary using AVX-512 gather instructions.
 *
 * Implemented directly with float-typed gathers rather than delegating to
 * the int32 path through pointer casts: if the i32 implementation's scalar
 * tail reads elements through int32_t lvalues, accessing float storage that
 * way is a strict-aliasing violation (C11 6.5p7) — presumably it does a
 * plain `dict[indices[i]]` copy; verify against carquet_avx512_gather_i32.
 * The typed version costs nothing and is unambiguously well-defined.
 *
 * @param dict     dictionary of float values (read-only)
 * @param indices  dictionary indices, one per output element
 * @param count    number of elements to gather
 * @param output   destination array holding at least count floats
 */
void carquet_avx512_gather_float(const float* dict, const uint32_t* indices,
                                 int64_t count, float* output) {
    int64_t i = 0;

    /* Main loop: 16 elements per iteration; scale 4 == sizeof(float) */
    for (; i + 16 <= count; i += 16) {
        __m512i idx = _mm512_loadu_si512((const __m512i*)(indices + i));
        __m512 result = _mm512_i32gather_ps(idx, dict, 4);
        _mm512_storeu_ps(output + i, result);
    }

    /* Medium tail: 8 elements at a time with the AVX2 gather
     * (note the AVX2 intrinsic takes base first, index second). */
    for (; i + 8 <= count; i += 8) {
        __m256i idx = _mm256_loadu_si256((const __m256i*)(indices + i));
        __m256 result = _mm256_i32gather_ps(dict, idx, 4);
        _mm256_storeu_ps(output + i, result);
    }

    /* Scalar tail for the last < 8 elements */
    for (; i < count; i++) {
        output[i] = dict[indices[i]];
    }
}
427372
/**
 * Gather double values from dictionary using AVX-512 gather instructions.
 *
 * Implemented directly with double-typed gathers rather than delegating to
 * the int64 path through `(const int64_t*)` casts: if the i64
 * implementation's scalar tail reads through int64_t lvalues, accessing
 * double storage that way is a strict-aliasing violation (C11 6.5p7) —
 * TODO confirm against carquet_avx512_gather_i64. The typed version is
 * equally fast and unambiguously well-defined.
 *
 * @param dict     dictionary of double values (read-only)
 * @param indices  dictionary indices, one per output element
 * @param count    number of elements to gather
 * @param output   destination array holding at least count doubles
 */
void carquet_avx512_gather_double(const double* dict, const uint32_t* indices,
                                  int64_t count, double* output) {
    int64_t i = 0;

    /* Main loop: 8 doubles per iteration; 8 x 32-bit indices fit in a
     * 256-bit vector, scale 8 == sizeof(double). */
    for (; i + 8 <= count; i += 8) {
        __m256i idx = _mm256_loadu_si256((const __m256i*)(indices + i));
        __m512d result = _mm512_i32gather_pd(idx, dict, 8);
        _mm512_storeu_pd(output + i, result);
    }

    /* Scalar tail for the last < 8 elements */
    for (; i < count; i++) {
        output[i] = dict[indices[i]];
    }
}
447382
448383/* ============================================================================
@@ -558,13 +493,8 @@ void carquet_avx512_unpack_bools(const uint8_t* input, uint8_t* output, int64_t
558493 uint64_t packed ;
559494 memcpy (& packed , input + byte_idx , 8 );
560495
561- /* Convert to mask */
562- __mmask64 mask = (__mmask64 )packed ;
563-
564- /* Create result: 1 where mask bit is set, 0 otherwise */
565- __m512i ones = _mm512_set1_epi8 (1 );
566- __m512i zeros = _mm512_setzero_si512 ();
567- __m512i result = _mm512_mask_mov_epi8 (zeros , mask , ones );
496+ /* Convert to mask and create result with maskz_set1 (1 where set, 0 otherwise) */
497+ __m512i result = _mm512_maskz_set1_epi8 ((__mmask64 )packed , 1 );
568498
569499 _mm512_storeu_si512 ((__m512i * )(output + i ), result );
570500 }
@@ -587,23 +517,30 @@ void carquet_avx512_pack_bools(const uint8_t* input, uint8_t* output, int64_t co
587517 for (; i + 64 <= count ; i += 64 ) {
588518 __m512i bools = _mm512_loadu_si512 ((const __m512i * )(input + i ));
589519
590- /* Compare with zero to get mask */
591- __mmask64 mask = _mm512_cmpneq_epi8_mask (bools , _mm512_setzero_si512 () );
520+ /* Use test_epi8_mask: bit is set if (a & b) != 0, i.e., if bool is non-zero */
521+ __mmask64 mask = _mm512_test_epi8_mask (bools , bools );
592522
593523 /* Store mask as 8 bytes */
594524 uint64_t packed = (uint64_t )mask ;
595525 memcpy (output + i / 8 , & packed , 8 );
596526 }
597527
598- /* Handle remaining */
599- for (; i < count ; i += 8 ) {
600- uint8_t byte = 0 ;
601- for (int64_t j = 0 ; j < 8 && i + j < count ; j ++ ) {
602- if (input [i + j ]) {
603- byte |= (1 << j );
604- }
605- }
606- output [i / 8 ] = byte ;
528+ /* Handle remaining elements with masked load */
529+ if (i < count ) {
530+ int64_t remaining = count - i ;
531+ /* Create mask for remaining elements: set bits 0..(remaining-1) */
532+ __mmask64 load_mask = (remaining >= 64 ) ? ~0ULL : ((1ULL << remaining ) - 1 );
533+
534+ /* Masked load zeros out elements beyond the mask */
535+ __m512i bools = _mm512_maskz_loadu_epi8 (load_mask , input + i );
536+
537+ /* Test for non-zero values */
538+ __mmask64 result_mask = _mm512_test_epi8_mask (bools , bools );
539+
540+ /* Write only the bytes we need */
541+ int64_t bytes_to_write = (remaining + 7 ) / 8 ;
542+ uint64_t packed = (uint64_t )result_mask ;
543+ memcpy (output + i / 8 , & packed , (size_t )bytes_to_write );
607544 }
608545}
609546
0 commit comments