@@ -78,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
 
 #define UNUSED GGML_UNUSED
 
+static inline int nearest_int(float fval) {
+    assert(fabsf(fval) <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
 // Functions to create the interleaved data layout formats
 
 // interleave 4 block_q4_0s in blocks of blck_size_interleave
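Aside on the nearest_int() helper added in the hunk above (editorial note, not part of the patch): it rounds via the classic magic-constant trick. Adding 12582912.0f (1.5 * 2^23) pushes the sum into [2^23, 2^24), where one float ULP equals 1, so the low 23 mantissa bits of the result hold round(fval) + 0x00400000; masking and subtracting recovers the integer. A minimal standalone sketch, assuming the default round-to-nearest-even FP mode (the name nearest_int_demo is illustrative only):

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>
    #include <string.h>

    // Same trick as the nearest_int() helper in the patch; valid for |fval| <= 4194303 (2^22 - 1),
    // which keeps the sum inside [2^23, 2^24) where the encoding below is exact.
    static inline int nearest_int_demo(float fval) {
        assert(fabsf(fval) <= 4194303.f);
        float val = fval + 12582912.f;           // 12582912 = 1.5 * 2^23: ULP of the sum is 1.0
        int i; memcpy(&i, &val, sizeof(int));    // reinterpret the float bits
        return (i & 0x007fffff) - 0x00400000;    // mantissa holds round(fval) + 0x00400000
    }

    int main(void) {
        const float samples[] = { 0.4f, 0.5f, 1.5f, -2.5f, 126.7f, -126.7f };
        for (int n = 0; n < 6; n++) {
            printf("%8.2f -> %4d (rintf: %4.0f)\n",
                   samples[n], nearest_int_demo(samples[n]), rintf(samples[n]));
        }
        return 0;
    }

Note the tie-breaking: under round-to-nearest-even this agrees with rintf() rather than roundf() (0.5 -> 0, 1.5 -> 2), which is relevant because the patch also swaps roundf() for nearest_int() in the scalar path below.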
@@ -559,6 +566,218 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
+#if defined(__AVX2__)
+    float iscale[4];
+    __m256 srcv[4][32];
+    __m256 iscale_vec[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            // Load elements into 4 AVX vectors
+            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
+            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
+            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
+            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
+
+            // Compute max(abs(e)) for the block
+            const __m256 signBit = _mm256_set1_ps( -0.0f );
+            __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
+            __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
+            __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
+            __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
+
+            __m256 maxAbs = _mm256_max_ps( abs0, abs1 );
+            maxAbs = _mm256_max_ps( maxAbs, abs2 );
+            maxAbs = _mm256_max_ps( maxAbs, abs3 );
+
+            __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
+            __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
+            __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
+            __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
+
+            __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1), _mm256_or_ps(mask2, mask3));
+
+            srcv[row_iter][0] = v0;
+            srcv[row_iter][1] = v1;
+            srcv[row_iter][2] = v2;
+            srcv[row_iter][3] = v3;
+
+            for (int sb = 1; sb < 8; sb++) {
+                // Temporarily stores absolute quant values
+                __m256 tempAbs = maxAbs;
+
+                // Load elements into 4 AVX vectors
+                __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 );
+                __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
+                __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
+                __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
+
+                // Compute max(abs(e)) for the block
+                __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
+                __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
+                __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
+                __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
+
+                maxAbs = _mm256_max_ps( maxAbs, abs0 );
+                maxAbs = _mm256_max_ps( maxAbs, abs1 );
+                maxAbs = _mm256_max_ps( maxAbs, abs2 );
+                maxAbs = _mm256_max_ps( maxAbs, abs3 );
+
+                __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
+                maskAbs = _mm256_and_ps( maskAbs, mask_prev );
+
+                mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
+                mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
+                mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
+                mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
+
+                __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1), _mm256_or_ps(mask2, mask3));
+                maskAbs = _mm256_or_ps(maskAbs, mask_curr);
+
+                srcv[row_iter][sb * 4] = v0;
+                srcv[row_iter][sb * 4 + 1] = v1;
+                srcv[row_iter][sb * 4 + 2] = v2;
+                srcv[row_iter][sb * 4 + 3] = v3;
+            }
+
+            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+            const float maxScalar = _mm_cvtss_f32( max4 );
+
+            __m256 maxScalarVec = _mm256_set1_ps(maxScalar);
+
+            __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
+            __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
+
+            const int mask = _mm256_movemask_ps(finalMask);
+            iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+
+            if (mask) {
+                iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar : 0.0f;
+            }
+
+            y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
+            iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
+        }
+
+        __m256i quants_interleaved[32];
+        for (int j = 0; j < 32; j++) {
+            // Apply the multiplier
+            __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
+            __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
+            __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
+            __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
+
+            // Round to nearest integer
+            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+            // Convert floats to integers
+            __m256i i0 = _mm256_cvtps_epi32( v0 );
+            __m256i i1 = _mm256_cvtps_epi32( v1 );
+            __m256i i2 = _mm256_cvtps_epi32( v2 );
+            __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+            // Convert int32 to int16
+            i0 = _mm256_packs_epi32( i0, i1 );
+            i2 = _mm256_packs_epi32( i2, i3 );
+            // Convert int16 to int8
+            i0 = _mm256_packs_epi16( i0, i2 );
+
+            // Permute and store the quantized weights in the required order after the pack instruction
+            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+            i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
+            quants_interleaved[j] = i0;
+        }
+
+        // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation
+        __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
+        shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
+        __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
+        shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
+        __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
+        shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
+
+        for (int k = 0; k < 4; k++) {
+            // Quants from four different sub blocks are taken
+            __m256i q0 = quants_interleaved[k * 8 + 0];
+            __m256i q1 = quants_interleaved[k * 8 + 1];
+            __m256i q2 = quants_interleaved[k * 8 + 2];
+            __m256i q3 = quants_interleaved[k * 8 + 3];
+            __m256i q4 = quants_interleaved[k * 8 + 4];
+            __m256i q5 = quants_interleaved[k * 8 + 5];
+            __m256i q6 = quants_interleaved[k * 8 + 6];
+            __m256i q7 = quants_interleaved[k * 8 + 7];
+
+
+            // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
+            __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
+            __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
+            __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
+            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
+            __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
+            sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
+
+            __m256i one = _mm256_set1_epi8(1);
+            __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
+
+            for (int l = 0; l < 3; l++) {
+                // Quants value shifted to process next two values from each sub block
+                q0 = _mm256_srli_epi64(q0, 16);
+                q2 = _mm256_srli_epi64(q2, 16);
+                q4 = _mm256_srli_epi64(q4, 16);
+                q6 = _mm256_srli_epi64(q6, 16);
+
+                sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
+                sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
+                sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
+                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
+                sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
+                sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
+
+                bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
+            }
+
+            // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
+            __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
+            __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
+            __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
+            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
+            __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
+            sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
+
+            __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
+
+            for (int l = 0; l < 3; l++) {
+                // Quants value shifted to process next two values from each sub block
+                q1 = _mm256_srli_epi64(q1, 16);
+                q3 = _mm256_srli_epi64(q3, 16);
+                q5 = _mm256_srli_epi64(q5, 16);
+                q7 = _mm256_srli_epi64(q7, 16);
+
+                sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
+                sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
+                sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
+                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
+                sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
+                sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
+
+                bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
+            }
+
+            // Overall bsums in interleaved fashion computed by adding results of both halves
+            __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
+            _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
+        }
+    }
+
+#else
+
     // scalar
     const int blck_size_interleave = 8;
     float srcv[4][QK_K];
@@ -597,10 +816,11 @@ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
             int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
 
             float x0 = srcv[src_id][src_offset] * iscale[src_id];
-            y[i].qs[j] = roundf(x0);
+            y[i].qs[j] = nearest_int(x0);
             y[i].bsums[index] += y[i].qs[j];
         }
     }
+#endif
 }
 
 static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
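A note on the scale-sign handling in the new AVX2 path (editorial sketch, not part of the patch): the maskAbs / finalMask bookkeeping decides whether the largest-magnitude element of a row is positive and, if so, negates the per-row scale, mirroring the scalar Q8_K rule where iscale is derived from the signed value of largest magnitude. A rough scalar equivalent for comparison only (q8_K_row_iscale is a hypothetical helper, and tie-breaking between equal-magnitude positive and negative values may differ from the vectorized masks):

    #include <math.h>
    #include <stddef.h>

    // Rough scalar equivalent of the per-row scale selection performed by the
    // AVX2 masks above (illustrative only; 'row' is one 256-element row of the block).
    static float q8_K_row_iscale(const float * row, size_t n) {
        float max  = 0.0f;  // signed value with the largest magnitude seen so far
        float amax = 0.0f;  // its absolute value
        for (size_t j = 0; j < n; j++) {
            const float ax = fabsf(row[j]);
            if (ax > amax) { amax = ax; max = row[j]; }
        }
        if (amax == 0.0f) {
            return 0.0f;    // all-zero row -> zero scale (and d = 0)
        }
        // Negative iscale when the dominant value is positive, positive otherwise:
        // the same effect as the 127 / -127 selection driven by 'mask' in the AVX2 code.
        return -127.f / max;
    }

Either sign makes full use of the [-127, 127] int8 range, and y[i].d stores 1/iscale so dequantization stays a single multiply per value.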