Commit cba0df3

Add vector version of quantize_q8_K_4x8 function
1 parent 022ad35 commit cba0df3

1 file changed

ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

Lines changed: 221 additions & 1 deletion
@@ -78,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
 
 #define UNUSED GGML_UNUSED
 
+static inline int nearest_int(float fval) {
+    assert(fabsf(fval) <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
 // Functions to create the interleaved data layout formats
 
 // interleave 4 block_q4_0s in blocks of blck_size_interleave
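
Note on the nearest_int helper added above: 12582912.f is 1.5 * 2^23, so for |fval| <= 4194303 (2^22 - 1) the addition lands in a range where the float spacing is exactly 1.0; the low 23 mantissa bits of the sum then hold round-to-nearest(fval) + 2^22, and masking with 0x007fffff followed by subtracting 0x00400000 recovers the rounded integer. Halfway cases round to even, matching _mm256_round_ps(..., _MM_ROUND_NEAREST) in the vector path below, whereas roundf rounds them away from zero. A minimal standalone sketch of the trick, not part of the diff, comparing it against lrintf (assumes the default round-to-nearest FP environment):

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

// Standalone copy of the magic-constant rounding used by nearest_int above.
static int nearest_int_sketch(float fval) {
    assert(fabsf(fval) <= 4194303.f);      // 2^22 - 1: range where the trick is exact
    float val = fval + 12582912.f;         // 12582912 = 1.5 * 2^23
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;  // keep the mantissa bits, drop the 2^22 bias
}

int main(void) {
    const float samples[] = { -2.5f, -1.5f, -0.49f, 0.0f, 0.5f, 1.5f, 2.5f, 126.7f };
    for (int n = 0; n < 8; n++) {
        // Both round ties to even under the default FP environment.
        printf("%8.2f -> %4d (lrintf: %4ld)\n",
               samples[n], nearest_int_sketch(samples[n]), lrintf(samples[n]));
    }
    return 0;
}

With ties to even, 0.5f maps to 0 and 2.5f to 2 from either routine, which is also what the packed rounding in the AVX2 path produces.
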
@@ -559,6 +566,218 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
+#if defined(__AVX2__)
+    float iscale[4];
+    __m256 srcv[4][32];
+    __m256 iscale_vec[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            // Load elements into 4 AVX vectors
+            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
+            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
+            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
+            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
+
+            // Compute max(abs(e)) for the block
+            const __m256 signBit = _mm256_set1_ps( -0.0f );
+            __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
+            __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
+            __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
+            __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
+
+            __m256 maxAbs = _mm256_max_ps( abs0, abs1 );
+            maxAbs = _mm256_max_ps( maxAbs, abs2 );
+            maxAbs = _mm256_max_ps( maxAbs, abs3 );
+
+            __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
+            __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
+            __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
+            __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
+
+            __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
+
+            srcv[row_iter][0] = v0;
+            srcv[row_iter][1] = v1;
+            srcv[row_iter][2] = v2;
+            srcv[row_iter][3] = v3;
+
+            for (int sb = 1; sb < 8; sb++) {
+                // Temporarily stores absolute quant values
+                __m256 tempAbs = maxAbs;
+
+                // Load elements into 4 AVX vectors
+                __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
+                __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
+                __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
+                __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
+
+                // Compute max(abs(e)) for the block
+                __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
+                __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
+                __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
+                __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
+
+                maxAbs = _mm256_max_ps( maxAbs, abs0 );
+                maxAbs = _mm256_max_ps( maxAbs, abs1 );
+                maxAbs = _mm256_max_ps( maxAbs, abs2 );
+                maxAbs = _mm256_max_ps( maxAbs, abs3 );
+
+                __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
+                maskAbs = _mm256_and_ps( maskAbs, mask_prev );
+
+                mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
+                mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
+                mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
+                mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
+
+                __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
+                maskAbs = _mm256_or_ps(maskAbs, mask_curr);
+
+                srcv[row_iter][sb * 4] = v0;
+                srcv[row_iter][sb * 4 + 1] = v1;
+                srcv[row_iter][sb * 4 + 2] = v2;
+                srcv[row_iter][sb * 4 + 3] = v3;
+            }
+
+            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+            const float maxScalar = _mm_cvtss_f32( max4 );
+
+            __m256 maxScalarVec = _mm256_set1_ps(maxScalar);
+
+            __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
+            __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
+
+            const int mask = _mm256_movemask_ps(finalMask);
+            iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+
+            if(mask) {
+                iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
+            }
+
+            y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
+            iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
+        }
+
+        __m256i quants_interleaved[32];
+        for (int j = 0; j < 32; j++) {
+            // Apply the multiplier
+            __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
+            __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
+            __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
+            __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
+
+            // Round to nearest integer
+            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+            // Convert floats to integers
+            __m256i i0 = _mm256_cvtps_epi32( v0 );
+            __m256i i1 = _mm256_cvtps_epi32( v1 );
+            __m256i i2 = _mm256_cvtps_epi32( v2 );
+            __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+            // Convert int32 to int16
+            i0 = _mm256_packs_epi32( i0, i1 );
+            i2 = _mm256_packs_epi32( i2, i3 );
+            // Convert int16 to int8
+            i0 = _mm256_packs_epi16( i0, i2 );
+
+            // Permute and store the quantized weights in the required order after the pack instruction
+            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+            i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
+            quants_interleaved[j] = i0;
+        }
+
+        // Masks to shuffle the quants of the corresponding sub blocks, rearranging the quants for vectorized bsums computation
+        __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
+        shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
+        __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
+        shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
+        __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
+        shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
+
+        for (int k = 0; k < 4; k++) {
+            // Quants from four different sub blocks are taken
+            __m256i q0 = quants_interleaved[k * 8 + 0];
+            __m256i q1 = quants_interleaved[k * 8 + 1];
+            __m256i q2 = quants_interleaved[k * 8 + 2];
+            __m256i q3 = quants_interleaved[k * 8 + 3];
+            __m256i q4 = quants_interleaved[k * 8 + 4];
+            __m256i q5 = quants_interleaved[k * 8 + 5];
+            __m256i q6 = quants_interleaved[k * 8 + 6];
+            __m256i q7 = quants_interleaved[k * 8 + 7];
+
+            // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
+            __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
+            __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
+            __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
+            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
+            __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
+            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
+
+            __m256i one = _mm256_set1_epi8(1);
+            __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
+
+            for(int l = 0; l < 3; l++) {
+                // Quants value shifted to process next two values from each sub block
+                q0 = _mm256_srli_epi64(q0, 16);
+                q2 = _mm256_srli_epi64(q2, 16);
+                q4 = _mm256_srli_epi64(q4, 16);
+                q6 = _mm256_srli_epi64(q6, 16);
+
+                sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
+                sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
+                sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
+                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
+                sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
+                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
+
+                bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
+            }
+
+            // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
+            __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
+            __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
+            __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
+            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
+            __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
+            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
+
+            __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
+
+            for(int l = 0; l < 3; l++) {
+                // Quants value shifted to process next two values from each sub block
+                q1 = _mm256_srli_epi64(q1, 16);
+                q3 = _mm256_srli_epi64(q3, 16);
+                q5 = _mm256_srli_epi64(q5, 16);
+                q7 = _mm256_srli_epi64(q7, 16);
+
+                sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
+                sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
+                sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
+                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
+                sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
+                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
+
+                bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
+            }
+
+            // Overall bsums in interleaved fashion computed by adding results of both halves
+            __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
+            _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
+        }
+    }
+
+#else
+
     // scalar
     const int blck_size_interleave = 8;
     float srcv[4][QK_K];
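
The maskAbs / finalMask bookkeeping in the row_iter loop above decides the sign of each row's scale: iscale becomes -127.f / maxScalar when the element attaining the row's maximum magnitude is positive (so that element quantizes to -127), and +127.f / maxScalar otherwise. A scalar sketch of that selection, written only as a reading aid; the helper name and the standalone form are assumptions, not code from this commit:

#include <math.h>

// Illustrative only: per-row scale selection equivalent to the mask logic above.
// x points to the row's 256 floats; the returned iscale is what gets splatted into
// iscale_vec[row_iter], and y[i].d[row_iter] stores 1/iscale (or 0 for an all-zero row).
static float row_iscale_sketch(const float * x, int n /* QK_K == 256 */) {
    float amax = 0.0f;          // largest magnitude seen so far
    int   max_is_positive = 0;  // is that magnitude attained by a positive value?
    for (int j = 0; j < n; j++) {
        const float ax = fabsf(x[j]);
        if (ax > amax) {
            amax = ax;
            max_is_positive = (x[j] > 0.0f);
        } else if (ax == amax && x[j] > 0.0f) {
            max_is_positive = 1;  // a positive value ties the current maximum
        }
    }
    if (amax == 0.0f) {
        return 0.0f;
    }
    // A non-zero movemask on finalMask corresponds to max_is_positive here.
    return (max_is_positive ? -127.f : 127.f) / amax;
}
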
@@ -597,10 +816,11 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
             int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
 
             float x0 = srcv[src_id][src_offset] * iscale[src_id];
-            y[i].qs[j] = roundf(x0);
+            y[i].qs[j] = nearest_int(x0);
             y[i].bsums[index] += y[i].qs[j];
         }
     }
+#endif
 }
 
 static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
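
The bsums computation above leans on _mm256_maddubs_epi16(one, q): with one holding 0x01 in every byte, each unsigned byte of one is multiplied by the corresponding signed quant byte and adjacent pairs are added, so 32 int8 quants collapse into 16 int16 partial sums per call. The shuffles and blends beforehand only arrange which sub block contributes each pair, and four such calls per half, accumulated with _mm256_add_epi16 and then added across the two halves, build the sums over 16 quants that are stored into y[i].bsums. A small standalone demo of the pairwise reduction (illustrative only, not from the commit; build with -mavx2):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    // 32 signed 8-bit "quants", here just -3 .. 28
    signed char q8[32];
    for (int j = 0; j < 32; j++) {
        q8[j] = (signed char)(j - 3);
    }

    const __m256i one = _mm256_set1_epi8(1);
    const __m256i q   = _mm256_loadu_si256((const __m256i *)q8);

    // one[b] (unsigned) * q[b] (signed), adjacent pairs summed into 16 x int16
    const __m256i sums = _mm256_maddubs_epi16(one, q);

    short s16[16];
    _mm256_storeu_si256((__m256i *)s16, sums);
    for (int j = 0; j < 16; j++) {
        printf("pair %2d: %4d + %4d = %4d\n", j, q8[2 * j], q8[2 * j + 1], s16[j]);
    }
    return 0;
}
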
