@@ -661,8 +661,6 @@ static inline int compare_fractions_desc(const void * a, const void * b) {
661661// exhaustive search with cumulative sums
662662// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
663663static float make_qkxs_quants (int n , int nmin , int nmax , const float * restrict x , const float * restrict weights , int8_t * restrict L , int8_t * restrict Laux , struct fraction * restrict Faux , bool signed_scale ) {
664- const int orig_nmin = nmin ;
665- const int orig_nmax = nmax ;
666664 float max = x [0 ];
667665 float min = x [0 ];
668666 float w_amax = weights [0 ] * fabsf (x [0 ]);
@@ -2143,6 +2141,8 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
21432141
21442142 float weight [QK4_0 ];
21452143 int8_t L [QK4_0 ];
2144+ int8_t Laux [QK4_0 ];
2145+ struct fraction Faux [8 * QK4_0 ];
21462146
21472147 float sum_x2 = 0 ;
21482148 for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
@@ -2153,7 +2153,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
21532153 const float * xb = x + QK4_0 * ib ;
21542154 const float * qw = quant_weights + QK4_0 * ib ;
21552155 for (int j = 0 ; j < QK4_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
2156- float d = make_qx_quants (QK4_0 , 8 , xb , L , 1 , weight );
2156+ float d = make_qkxs_quants (QK4_0 , - 8 , 7 , xb , weight , L , Laux , Faux , true );
21572157 y [ib ].d = GGML_FP32_TO_FP16 (d );
21582158 for (int j = 0 ; j < 16 ; ++ j ) {
21592159 y [ib ].qs [j ] = L [j ] | (L [j + 16 ] << 4 );
@@ -2231,6 +2231,8 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
22312231
22322232 float weight [QK5_0 ];
22332233 int8_t L [QK5_0 ];
2234+ int8_t Laux [QK5_0 ];
2235+ struct fraction Faux [16 * QK5_0 ];
22342236
22352237 float sum_x2 = 0 ;
22362238 for (int j = 0 ; j < n_per_row ; ++ j ) sum_x2 += x [j ]* x [j ];
@@ -2241,7 +2243,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
22412243 const float * xb = x + QK5_0 * ib ;
22422244 const float * qw = quant_weights + QK5_0 * ib ;
22432245 for (int j = 0 ; j < QK5_0 ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
2244- float d = make_qx_quants (QK5_0 , 16 , xb , L , 1 , weight );
2246+ float d = make_qkxs_quants (QK5_0 , - 16 , 15 , xb , weight , L , Laux , Faux , true );
22452247 y [ib ].d = GGML_FP32_TO_FP16 (d );
22462248
22472249 uint32_t qh = 0 ;
@@ -2403,6 +2405,74 @@ void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y,
24032405 }
24042406}
24052407
2408+ static void quantize_row_tq1_0_impl (const float * restrict x , block_tq1_0 * restrict y , int64_t n_per_row , const float * quant_weights ) {
2409+ if (!quant_weights ) {
2410+ quantize_row_tq1_0_ref (x , y , n_per_row );
2411+ return ;
2412+ }
2413+
2414+ float weight [QK_K ];
2415+ int8_t L [QK_K ];
2416+ int8_t Laux [QK_K ];
2417+ struct fraction Faux [1 * QK_K ];
2418+
2419+ float sum_x2 = 0 ;
2420+ for (int j = 0 ; j < n_per_row ; ++ j ) { sum_x2 += x [j ]* x [j ]; }
2421+ float sigma2 = sum_x2 /n_per_row ;
2422+
2423+ const int64_t nb = n_per_row /QK_K ;
2424+ for (int ib = 0 ; ib < nb ; ++ ib ) {
2425+ const float * xb = x + QK_K * ib ;
2426+ const float * qw = quant_weights + QK_K * ib ;
2427+ const int8_t * Lptr = L ;
2428+ for (int j = 0 ; j < QK_K ; ++ j ) { weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]); }
2429+ float d = make_qkxs_quants (QK_K , -1 , 1 , xb , weight , L , Laux , Faux , false);
2430+ y [ib ].d = GGML_FP32_TO_FP16 (d );
2431+
2432+ // 5 elements per byte, along 32 bytes
2433+ for (size_t j = 0 ; j < sizeof (y -> qs ) - sizeof (y -> qs ) % 32 ; j += 32 ) {
2434+ for (size_t m = 0 ; m < 32 ; ++ m ) {
2435+ uint8_t q = 0 ;
2436+ for (size_t n = 0 ; n < 5 ; ++ n ) {
2437+ q *= 3 ;
2438+ q += Lptr [m + n * 32 ];
2439+ }
2440+ // ceiling division (243 == pow(3, 5))
2441+ q = ((uint16_t )q * 256 + (243 - 1 )) / 243 ;
2442+ y [ib ].qs [j + m ] = q ;
2443+ }
2444+ Lptr += 5 * 32 ;
2445+ }
2446+ // along 16 bytes
2447+ for (size_t j = sizeof (y -> qs ) - sizeof (y -> qs ) % 32 ; j < sizeof (y -> qs ); j += 16 ) {
2448+ for (size_t m = 0 ; m < 16 ; ++ m ) {
2449+ uint8_t q = 0 ;
2450+ for (size_t n = 0 ; n < 5 ; ++ n ) {
2451+ q *= 3 ;
2452+ q += Lptr [m + n * 16 ];
2453+ }
2454+ // ceiling division (243 == pow(3, 5))
2455+ q = ((uint16_t )q * 256 + (243 - 1 )) / 243 ;
2456+ y [ib ].qs [j + m ] = q ;
2457+ }
2458+ Lptr += 5 * 16 ;
2459+ }
2460+ // 4 elements per byte
2461+ for (size_t j = 0 ; j < sizeof (y -> qh ); ++ j ) {
2462+ uint8_t q = 0 ;
2463+ for (size_t m = 0 ; m < 4 ; ++ m ) {
2464+ q *= 3 ;
2465+ q += Lptr [j + m * sizeof (y -> qh )];
2466+ }
2467+ // shift the first value to the most significant trit
2468+ q *= 3 ;
2469+ // ceiling division (243 == pow(3, 5))
2470+ q = ((uint16_t )q * 256 + (243 - 1 )) / 243 ;
2471+ y [ib ].qh [j ] = q ;
2472+ }
2473+ }
2474+ }
2475+
24062476void quantize_row_tq2_0_ref (const float * restrict x , block_tq2_0 * restrict y , int64_t k ) {
24072477 assert (k % QK_K == 0 );
24082478 const int64_t nb = k / QK_K ;
@@ -2435,17 +2505,69 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
24352505 }
24362506}
24372507
2508+
2509+ static void quantize_row_tq2_0_impl (const float * restrict x , block_tq2_0 * restrict y , int64_t n_per_row , const float * quant_weights ) {
2510+ if (!quant_weights ) {
2511+ quantize_row_tq2_0_ref (x , y , n_per_row );
2512+ return ;
2513+ }
2514+
2515+ float weight [QK_K ];
2516+ int8_t L [QK_K ];
2517+ int8_t Laux [QK_K ];
2518+ struct fraction Faux [2 * QK_K ];
2519+
2520+ float sum_x2 = 0 ;
2521+ for (int j = 0 ; j < n_per_row ; ++ j ) { sum_x2 += x [j ]* x [j ]; }
2522+ float sigma2 = sum_x2 /n_per_row ;
2523+
2524+ const int64_t nb = n_per_row /QK_K ;
2525+ for (int ib = 0 ; ib < nb ; ++ ib ) {
2526+ const float * xb = x + QK_K * ib ;
2527+ const float * qw = quant_weights + QK_K * ib ;
2528+ for (int j = 0 ; j < QK_K ; ++ j ) { weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]); }
2529+ float d = make_qkxs_quants (QK_K , -1 , 2 , xb , weight , L , Laux , Faux , true);
2530+ y [ib ].d = GGML_FP32_TO_FP16 (d );
2531+
2532+ for (size_t j = 0 ; j < sizeof (y -> qs ); j += 32 ) {
2533+ for (size_t m = 0 ; m < 32 ; ++ m ) {
2534+ uint8_t q = 0 ;
2535+ for (size_t n = 0 ; n < 4 ; ++ n ) {
2536+ q += (L [4 * j + m + n * 32 ] & 3 ) << (2 * n );
2537+ }
2538+ y [ib ].qs [j + m ] = q ;
2539+ }
2540+ }
2541+ }
2542+ }
2543+
24382544size_t quantize_tq1_0 (const float * restrict src , void * restrict dst , int64_t nrow , int64_t n_per_row , const float * quant_weights ) {
2439- (void )quant_weights ; // not used
2440- const size_t row_size = ggml_row_size (GGML_TYPE_TQ1_0 , n_per_row );
2441- quantize_row_tq1_0_ref (src , dst , (int64_t )nrow * n_per_row );
2545+ if (!quant_weights ) {
2546+ quantize_row_tq1_0_ref (src , dst , (int64_t )nrow * n_per_row );
2547+ return nrow * ggml_row_size (GGML_TYPE_TQ1_0 , n_per_row );
2548+ }
2549+ size_t row_size = ggml_row_size (GGML_TYPE_TQ1_0 , n_per_row );
2550+ char * qrow = (char * )dst ;
2551+ for (int64_t row = 0 ; row < nrow ; ++ row ) {
2552+ quantize_row_tq1_0_impl (src , (block_tq1_0 * )qrow , n_per_row , quant_weights );
2553+ src += n_per_row ;
2554+ qrow += row_size ;
2555+ }
24422556 return nrow * row_size ;
24432557}
24442558
24452559size_t quantize_tq2_0 (const float * restrict src , void * restrict dst , int64_t nrow , int64_t n_per_row , const float * quant_weights ) {
2446- (void )quant_weights ; // not used
2447- const size_t row_size = ggml_row_size (GGML_TYPE_TQ2_0 , n_per_row );
2448- quantize_row_tq2_0_ref (src , dst , (int64_t )nrow * n_per_row );
2560+ if (!quant_weights ) {
2561+ quantize_row_tq2_0_ref (src , dst , (int64_t )nrow * n_per_row );
2562+ return nrow * ggml_row_size (GGML_TYPE_TQ2_0 , n_per_row );
2563+ }
2564+ size_t row_size = ggml_row_size (GGML_TYPE_TQ2_0 , n_per_row );
2565+ char * qrow = (char * )dst ;
2566+ for (int64_t row = 0 ; row < nrow ; ++ row ) {
2567+ quantize_row_tq2_0_impl (src , (block_tq2_0 * )qrow , n_per_row , quant_weights );
2568+ src += n_per_row ;
2569+ qrow += row_size ;
2570+ }
24492571 return nrow * row_size ;
24502572}
24512573
0 commit comments