@@ -660,97 +660,148 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
660660
661661// exhaustive search with cumulative sums
662662// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
663- static float make_qkxs_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , struct fraction * restrict Faux , bool signed_scale ) {
664- float max = 0.0f ;
665- float amax = 0.0f ;
666- for (int i = 0 ; i < n ; ++ i ) {
667- float ax = fabsf (x [i ]);
668- if (ax > amax ) {
669- amax = ax ;
670- max = x [i ];
663+ static float make_qkxs_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , int8_t * restrict Laux , struct fraction * restrict Faux , bool signed_scale ) {
664+ const int orig_nmin = nmin ;
665+ const int orig_nmax = nmax ;
666+ float max = x [0 ];
667+ float min = x [0 ];
668+ float w_amax = weights [0 ] * fabsf (x [0 ]);
669+ int max_i = 0 ;
670+ int w_amax_i = 0 ;
671+ int min_i = 0 ;
672+ for (int i = 1 ; i < n ; ++ i ) {
673+ if (x [i ] < min ) { min = x [i ]; min_i = i ; }
674+ if (x [i ] > max ) { max = x [i ]; max_i = i ; }
675+ // Find the most important value
676+ const float w = weights [i ];
677+ const float wax = w * fabsf (x [i ]);
678+ if (wax > w_amax ) {
679+ w_amax = wax ;
680+ w_amax_i = i ;
681+ }
682+ }
683+ const int amax_i = fabsf (min ) > fabsf (max ) ? min_i : max_i ;
684+ const float amax = fabsf (x [amax_i ]);
685+
686+ if (amax < GROUP_MAX_EPS ) { // all zero
687+ for (int i = 0 ; i < n ; ++ i ) {
688+ L [i ] = 0 ;
671689 }
690+ return 0.0f ;
672691 }
673692 bool negative_scale = false;
674693 if (signed_scale && - nmin != nmax ) {
675694 // the max side should have the biggest range
676- if ((max < 0.0f ) == (- nmin < nmax )) {
695+ // FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.])
696+ // or is it some other condition?
697+ if ((x [amax_i ] < 0.0f ) == (- nmin < nmax )) {
677698 // [-4, 3] ==> [-3, 4]
678- int tmp = nmin ;
699+ const int tmp = nmin ;
700+ const float ftmp = min ;
679701 nmin = - nmax ;
680702 nmax = - tmp ;
703+ min = - max ;
704+ max = - ftmp ;
681705 negative_scale = true;
682706 }
683707 }
684- if (amax < GROUP_MAX_EPS ) { // all zero
708+
709+ // Find the max range in [0, amax_range] which doesn't result in clamping.
710+ // This is the range from the side which would clamp first (biggest ratio of max to nmax).
711+ int amax_range ;
712+ float range_max ;
713+ if (fabsf (- max * nmin ) < fabsf (- min * nmax )) {
714+ amax_range = MAX (0 , - nmin );
715+ range_max = fabsf (min );
716+ } else {
717+ amax_range = MAX (0 , nmax );
718+ range_max = fabsf (max );
719+ }
720+ float sumlx = 0.0f ;
721+ float suml2 = 0.0f ;
722+ float scale = 0.0f ;
723+ float best = 0.0f ;
724+ float best_denom = 1.0f ;
725+ if (amax_range > 1 ) {
726+ // The smallest non-redundant iscale makes the first clamped value half+1 its max integer value.
727+ // Proof: anything smaller has a representable vector with values twice as big.
728+ const float iscale = ((float )(amax_range / 2 + 1 ))/range_max * (negative_scale ? -1.0f : 1.0f );
729+ for (int i = 0 ; i < n ; ++ i ) {
730+ const float w = weights [i ];
731+ int l = MAX (nmin , MIN (lroundf (x [i ] * iscale ), nmax ));
732+ if (negative_scale ) { l = - l ; }
733+ Laux [i ] = l ;
734+ L [i ] = l ;
735+ suml2 += w * l * l ;
736+ sumlx += w * l * x [i ];
737+ }
738+ best = sumlx * sumlx ;
739+ best_denom = suml2 ; // should never be zero
740+ scale = sumlx / suml2 ;
741+ } else {
685742 for (int i = 0 ; i < n ; ++ i ) {
743+ Laux [i ] = 0 ;
686744 L [i ] = 0 ;
687745 }
688- return 0.0f ;
689746 }
747+
748+ const int imax_range = MAX (0 , (x [w_amax_i ] < 0.0f ) ? - nmin : nmax );
749+ const int max_odd = 2 * (imax_range + 1 ) + 1 ;
750+ const float wmax = fabsf (x [w_amax_i ]);
690751 int n_frac = 0 ;
691752 for (int i = 0 ; i < n ; ++ i ) {
692753 // assuming nmin <= nmax
693- const int odd_max = MAX (0 , x [i ] < 0 ? - nmin : nmax );
694- const int odd_min = MAX (0 , x [i ] < 0 ? - nmax : nmin );
754+ const int odd_max = MAX (abs ( Laux [ i ]) , x [i ] < 0.0f ? - nmin : nmax );
755+ const int odd_min = MAX (abs ( Laux [ i ]) , x [i ] < 0.0f ? - nmax : nmin );
695756 const float v = fabsf (x [i ]);
696- // fprintf(stderr, "%s: i=%d, odd_min=%d, odd_max=%d\n", __func__, i, odd_min, odd_max) ;
757+ const float v_max_odd = v * max_odd ;
697758 for (int j = odd_min ; j < odd_max ; ++ j ) {
698759 const float odd = 2 * j + 1 ;
699- Faux [n_frac ++ ] = (struct fraction ){
700- .numer = v ,
701- .denom = odd ,
702- .i = i ,
703- };
760+ if (wmax * odd < v_max_odd ) {
761+ Faux [n_frac ++ ] = (struct fraction ){
762+ .numer = v ,
763+ .denom = odd ,
764+ .i = i ,
765+ };
766+ } else {
767+ // stop when the inverse scale would result in clamping the max (FIXME: most important) value
768+ break ;
769+ }
704770 }
705771 }
706772
707773 qsort (Faux , n_frac , sizeof (struct fraction ), compare_fractions_desc );
708774
709- float iscale = 0.0f ;
710- {
711- float sumlx = 0.0f ;
712- float suml2 = 0.0f ;
713- float best = 0.0f ;
714- float best_denom = 1.0f ;
715- for (int i = 0 ; i < n_frac ; ++ i ) {
716- // maximize the weighted cosine
717- const int ii = Faux [i ].i ;
718- const float w = weights ? weights [ii ] : x [ii ] * x [ii ];
719- sumlx += w * Faux [i ].numer ;
720- suml2 += w * Faux [i ].denom ;
721- const float current = sumlx * sumlx ;
722- // fprintf(stderr, "%s: Faux[%d]=(%f/%f) * %f, square(sumlx)=%f, suml2=%f, k*cos2=%f\n", __func__, i, Faux[i].numer, Faux[i].denom, Faux[i].weight, current, suml2, current / suml2);
723- // use the last in case of equality
724- // FIXME: > or >= ?? Why does [0, 0, 1] rounds to [0, 0, 0] with >= ?
725- if (suml2 > 0.0f && current * best_denom > best * suml2 ) {
726- best = current ;
727- best_denom = suml2 ;
728- iscale = Faux [i ].numer > 0.0f ? Faux [i ].denom / (2.0f * Faux [i ].numer ) : 0.0f ;
729- if (!isfinite (iscale )) {
730- fprintf (stderr , "%s: iscale is not finite, %f/(2*%f)\n" , __func__ , Faux [i ].denom , Faux [i ].numer );
775+ int best_p_i = -1 ; // consecutive with 0..n_frac
776+ for (int i = 0 ; i < n_frac ; ++ i ) {
777+ // maximize the weighted cosine
778+ const int ii = Faux [i ].i ;
779+ const float w = weights ? weights [ii ] : x [ii ] * x [ii ];
780+ sumlx += w * Faux [i ].numer ;
781+ suml2 += w * Faux [i ].denom ;
782+ const float current = sumlx * sumlx ;
783+ Laux [ii ] += x [ii ] < 0.0f ? -1 : 1 ;
784+ if (suml2 > 0.0f && Faux [i ].numer > 0.0f && current * best_denom > best * suml2 ) {
785+ best = current ;
786+ best_denom = suml2 ;
787+ scale = sumlx / suml2 ;
788+ if (i == best_p_i + 1 ) {
789+ // reduce copies for consecutive bests
790+ L [ii ] += x [ii ] < 0.0f ? -1 : 1 ;
791+ } else {
792+ for (int j = 0 ; j < n ; ++ j ) {
793+ L [j ] = Laux [j ];
731794 }
732795 }
796+ best_p_i = i ;
733797 }
734798 }
735- // (very) small fudging necessary because floats otherwise round to nearest even
736- iscale = iscale * ((float )((1 << 23 ) + 1 ) / (float )(1 << 23 ));
737-
738- float sumlx = 0.0f ;
739- float suml2 = 0.0f ;
740799 for (int i = 0 ; i < n ; ++ i ) {
741- // Rounding away from zero is assumed by the search algorithm above.
742- int l = MAX (nmin , MIN (lroundf (x [i ] * iscale ), nmax ));
743- if (negative_scale ) {
744- l = - l ;
745- }
746- L [i ] = negative_scale ? l + nmax : l - nmin ;
747- float w = weights ? weights [i ] : x [i ] * x [i ];
748- // weighted projection scale
749- sumlx += w * x [i ] * l ;
750- suml2 += w * l * l ;
800+ L [i ] = negative_scale ? (- L [i ] + nmax ) : (L [i ] + - nmin );
801+ GGML_ASSERT (L [i ] >= 0 && L [i ] <= nmax - nmin );
751802 }
752803
753- return suml2 > 0.0f ? sumlx / suml2 : 0.0f ;
804+ return negative_scale ? - scale : scale ;
754805}
755806
756807// non-linear exhaustive search with cumulative sums
@@ -1234,6 +1285,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
12341285 const int nb = k / QK_K ;
12351286
12361287 int8_t L [QK_K ];
1288+ int8_t Laux [16 ];
12371289 struct fraction Faux [16 * 4 ];
12381290 float scales [QK_K / 16 ];
12391291 float weights [16 ];
@@ -1247,7 +1299,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
12471299 float max_scale = 0 ;
12481300 float amax = 0 ;
12491301 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
1250- scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weights , L + 16 * j , Faux , true);
1302+ scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weights , L + 16 * j , Laux , Faux , true);
12511303 // scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
12521304 float scale = fabsf (scales [j ]);
12531305 if (scale > amax ) {
@@ -1367,6 +1419,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
13671419 const int nb = n_per_row / QK_K ;
13681420
13691421 int8_t L [QK_K ];
1422+ int8_t Laux [16 ];
13701423 float scales [QK_K / 16 ];
13711424 float weight [16 ];
13721425 float sw [QK_K / 16 ];
@@ -1391,14 +1444,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
13911444 sw [j ] = sumw ;
13921445
13931446 // scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
1394- scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weight , L + 16 * j , Faux , true);
1447+ scales [j ] = make_qkxs_quants (16 , -4 , 3 , x + 16 * j , weight , L + 16 * j , Laux , Faux , true);
13951448
13961449 }
13971450
13981451 memset (y [i ].scales , 0 , 12 );
13991452
14001453 // float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
1401- float d_block = make_qkxs_quants (QK_K /16 , -32 , 31 , scales , sw , Ls , Faux , true);
1454+ float d_block = make_qkxs_quants (QK_K /16 , -32 , 31 , scales , sw , Ls , Laux , Faux , true);
14021455 for (int j = 0 ; j < QK_K /16 ; ++ j ) {
14031456 int l = Ls [j ];
14041457 if (j < 8 ) {
@@ -4856,11 +4909,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48564909 for (int j = 0 ; j < block_size ; ++ j ) weight [j ] = sqrtf (sigma2 + xb [j ]* xb [j ]);
48574910 // for (int j = 0; j < block_size; ++j) weight[j] = 1;
48584911 }
4859- float amax = 0 , max = 0 ;
4912+ float amax = 0 ;
48604913 for (int j = 0 ; j < block_size ; ++ j ) {
48614914 float ax = fabsf (xb [j ]);
48624915 if (ax > amax ) {
4863- amax = ax ; max = xb [ j ];
4916+ amax = ax ;
48644917 }
48654918 }
48664919 if (amax < GROUP_MAX_EPS ) {
0 commit comments