@@ -902,26 +902,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         return std::isfinite(total_err) ? total_err : infinity;
     };
 
-    // Scaling factor to increase lambda when activations are concentrated
-    auto directional_scale = [&](const float * values, const float * activations, int64_t n_per_row) {
-        if (!activations) { return 1.0f; }
-        double sum_v   = 0.0;
-        double sum_aw2 = 0.0;
-        double sum_a2  = 0.0;
-        for (int64_t j = 0; j < n_per_row; ++j) {
-            const double v = values ? std::max(0.0f, values[j]) : 1.0;
-            const double a = activations[j];
-            sum_v   += v;
-            sum_aw2 += v * a * a;
-            sum_a2  += a * a;
-        }
-        const double rms_a = std::sqrt(sum_a2 / std::max(1.0, (double)n_per_row));
-        const double denom = std::sqrt(std::max(epsilon, sum_v)) * std::max(epsilon, rms_a);
-        const double scale = denom > 0.0 ? std::sqrt(sum_aw2) / denom : 1.0;
-
-        return (float)std::clamp(scale, 0.5, 2.0);
-    };
-
     // Higher precision but much longer to compute
     auto precise_lambda = [&](const ggml_tensor * t,
                               const std::vector<float> & f32_sample,
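The deleted `directional_scale` helper boosted lambda when activation energy concentrated on high-importance columns: it compared sqrt(Σ v·a²) against the uniform-spread baseline sqrt(Σ v)·rms(a) and clamped the ratio to [0.5, 2.0]. A self-contained sketch of that heuristic on toy data, for reference; the demo name is illustrative and the `epsilon` guards are replaced by a simple zero check:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reproduces the removed heuristic: ~1.0 when activation energy is spread
// evenly across columns, >1.0 when it concentrates on high-importance ones.
static float directional_scale_demo(const std::vector<float> & v, const std::vector<float> & a) {
    double sum_v = 0.0, sum_aw2 = 0.0, sum_a2 = 0.0;
    for (size_t j = 0; j < a.size(); ++j) {   // assumes v.size() == a.size()
        const double vj = std::max(0.0f, v[j]);
        sum_v   += vj;
        sum_aw2 += vj * a[j] * a[j];
        sum_a2  += a[j] * a[j];
    }
    const double rms_a = std::sqrt(sum_a2 / std::max<double>(1.0, a.size()));
    const double denom = std::sqrt(sum_v) * rms_a;
    const double scale = denom > 0.0 ? std::sqrt(sum_aw2) / denom : 1.0;
    return (float)std::clamp(scale, 0.5, 2.0);
}

int main() {
    std::printf("even:         %.3f\n", directional_scale_demo({1, 1, 1, 1}, {1, 1, 1, 1})); // 1.000
    std::printf("concentrated: %.3f\n", directional_scale_demo({4, 1, 1, 1}, {2, 0, 0, 0})); // ~1.512
}
```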
@@ -979,11 +959,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         if (ratios.empty()) { return 0.0f; }
 
         std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
-        double lambda = ratios[ratios.size() / 2];
-
-        const float scale = directional_scale(values, activations, n_per_row);
-        lambda *= scale;
-        lambda = std::clamp(lambda, 0.0, 8.0);
+        const double lambda = std::clamp(ratios[ratios.size() / 2], 0.0, 8.0);
 
         return (float)lambda;
     };
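With the scaling step gone, `fast_lambda` reduces to the clamped median of the per-column error ratios. `std::nth_element` is a good fit here: it only guarantees that the element at the middle position is the one a full sort would put there, which is all a median needs, at O(n) average cost instead of O(n log n). A minimal sketch of the pattern:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<double> ratios = {0.3, 12.0, 1.7, 0.9, 4.2};
    // Partial sort: only ratios[size/2] lands in its fully-sorted position.
    std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
    const double lambda = std::clamp(ratios[ratios.size() / 2], 0.0, 8.0);
    std::printf("lambda = %.1f\n", lambda); // 1.7 -- the 12.0 outlier has no pull
}
```

The outlier insensitivity is the usual reason to prefer a median over a mean for this kind of estimate.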
@@ -1007,8 +983,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         double base = 1.0 - s * s / (d * s2 + epsilon);
         base = std::clamp(base, 0.0, 1.0);
 
-        const double scale = directional_scale(values, activations, n_per_row);
-        const double lambda = std::clamp(base * scale, 0.0, 1.0) * 8.0;
+        const double lambda = std::clamp(base, 0.0, 1.0) * 8.0;
 
         return (float)lambda;
     };
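Assuming, as the names suggest, that `s` and `s2` are the sum and sum of squares of the activations over `d` entries, `s*s / (d*s2)` is the squared mean over the mean square, which Cauchy-Schwarz bounds to [0, 1]. So `base` is 0 for a perfectly uniform vector and approaches 1 as the mass concentrates on few entries, and lambda scales that into [0, 8]. A small sketch (the `1e-9` stands in for `epsilon`):

```cpp
#include <cstdio>
#include <vector>

// base = 1 - (sum a)^2 / (n * sum a^2): a dispersion measure in [0, 1].
static double base_of(const std::vector<double> & a) {
    double s = 0.0, s2 = 0.0;
    for (double x : a) { s += x; s2 += x * x; }
    const double d = (double)a.size();
    return 1.0 - s * s / (d * s2 + 1e-9);
}

int main() {
    std::printf("uniform:      %.3f\n", base_of({1, 1, 1, 1})); // 0.000 -> lambda 0
    std::printf("concentrated: %.3f\n", base_of({4, 0, 0, 0})); // 0.750 -> lambda 6
}
```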
@@ -1159,8 +1134,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         {
             const float * values      = values_sample.empty()      ? nullptr : values_sample.data();
             const float * activations = activations_sample.empty() ? nullptr : activations_sample.data();
-            bias_lambda = params->precise_lambda ? precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates) :
-                                                   fast_lambda(values, activations, n_per_row);
+            if (params->bpw_bias == 1) {
+                bias_lambda = fast_lambda(values, activations, n_per_row);
+            } else if (params->bpw_bias == 2) {
+                bias_lambda = precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates);
+            }
         }
 
         // Now evaluate candidates
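Note the fall-through: when `params->bpw_bias == 0` neither branch runs, so `bias_lambda` keeps its (presumably zero) initial value and candidate scoring degenerates to pure MSE, matching the "no bias (MSE only)" wording in the log message below. A toy sketch of the dispatch, with stand-in estimator bodies:

```cpp
#include <cstdio>

static float fast_lambda_stub()    { return 1.5f; }  // stand-in for fast_lambda(...)
static float precise_lambda_stub() { return 1.8f; }  // stand-in for precise_lambda(...)

static float pick_bias_lambda(int bpw_bias) {
    float bias_lambda = 0.0f;                 // 0: no bias, pure MSE scoring
    if (bpw_bias == 1) {
        bias_lambda = fast_lambda_stub();     // cheap median-of-ratios estimate
    } else if (bpw_bias == 2) {
        bias_lambda = precise_lambda_stub();  // slow, per-candidate estimate
    }
    return bias_lambda;
}

int main() {
    for (int mode = 0; mode <= 2; ++mode) {
        std::printf("bpw_bias=%d -> lambda=%.1f\n", mode, pick_bias_lambda(mode));
    }
}
```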
@@ -1656,7 +1634,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         } else {
             LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate - ", __func__);
         }
-        LLAMA_LOG_INFO("using %s\n", params->precise_lambda ? "precise lambda (slow)" : "fast lambda");
+        const char * msg[] = {"no bias (MSE only)", "fast (default)", "precise (slow)"};
+        LLAMA_LOG_INFO("using %s error estimation\n", msg[params->bpw_bias]);
         LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
         bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params, nthread);
     } else {
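One assumption the new log line makes: `msg[params->bpw_bias]` reads out of bounds if `bpw_bias` ever falls outside 0..2, so the value needs validation wherever it is parsed. A defensive variant would clamp the index first; a sketch, not part of the patch:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const char * msg[] = {"no bias (MSE only)", "fast (default)", "precise (slow)"};
    const int bpw_bias = 7;                      // hypothetical out-of-range input
    const int idx = std::clamp(bpw_bias, 0, 2);  // never reads past msg[]
    std::printf("using %s error estimation\n", msg[idx]);
}
```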
@@ -1967,7 +1946,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /* .tensor_type    =*/ nullptr,
         /* .prune_layers   =*/ nullptr,
         /* .target_bpw     =*/ -1.0f,
-        /* .precise_lambda =*/ false
+        /* .bpw_bias       =*/ 1
     };
 
     return result;
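For callers, the new integer field replaces the old boolean. A usage sketch, assuming the patched `llama.h` declares `bpw_bias` and `target_bpw` as the defaults block above implies:

```cpp
#include "llama.h"

int main() {
    llama_model_quantize_params p = llama_model_quantize_default_params();
    p.target_bpw = 4.25f;  // bpw_bias is presumably only consulted when a target bpw is set
    p.bpw_bias   = 2;      // 0 = MSE only, 1 = fast median estimate (default), 2 = precise but slow
    // ... pass &p to llama_model_quantize() as usual
    return 0;
}
```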