@@ -637,6 +637,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
637637 float bpw;
638638 size_t bytes;
639639 double error;
640+ double mse = 0.0 ;
641+ double proj = 0.0 ;
640642 };
641643
642644 struct tensor_info {
@@ -1340,9 +1342,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
13401342 const ggml_type tensor_types = compatible_candidates[i];
13411343 const auto bpw = (float )tensor_bpw (tensor, tensor_types);
13421344 const size_t bytes = tensor_bytes (tensor, tensor_types);
1345+ double mse = 0.0 ;
1346+ double proj = 0.0 ;
13431347 const auto err = estimate_error (tensor, tensor_types, f32_sample, rows_sample, values, activations,
1344- tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda);
1345- eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err };
1348+ tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj );
1349+ eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err, mse, proj };
13461350 }
13471351 });
13481352 }
@@ -1354,8 +1358,48 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
13541358 check_signal_handler (all);
13551359 }
13561360
1357- for (auto &c : eval_candidates) {
1358- if (c.bytes > 0 ) { info.candidate .push_back (c); }
1361+ // Check if biasing is needed
1362+ bool bias_needed = false ;
1363+ if (!lambdas.empty ()) {
1364+ int min_mse = -1 ;
1365+ int min_bias = -1 ;
1366+ {
1367+ double best_mse = std::numeric_limits<double >::infinity ();
1368+ double best_err = std::numeric_limits<double >::infinity ();
1369+ for (int i = 0 ; i < (int )eval_candidates.size (); ++i) {
1370+ const auto & c = eval_candidates[i];
1371+ if (c.bytes == 0 ) { continue ; }
1372+ if (c.mse < best_mse) {
1373+ best_mse = c.mse ;
1374+ min_mse = i;
1375+ }
1376+ if (c.error < best_err) {
1377+ best_err = c.error ;
1378+ min_bias = i;
1379+ }
1380+ }
1381+ }
1382+
1383+ if (min_mse != min_bias) {
1384+ bias_needed = true ;
1385+ } else {
1386+ double max_rel_bias = 0.0 ;
1387+ for (const auto & c : eval_candidates) {
1388+ if (c.bytes == 0 ) { continue ; }
1389+ const double mse = std::max (c.mse , epsilon);
1390+ const double bias_term = std::max (0.0 , c.error - c.mse );
1391+ const double rel = bias_term / mse;
1392+ max_rel_bias = std::max (rel, max_rel_bias);
1393+ }
1394+
1395+ bias_needed = max_rel_bias >= 0.5 ; // >= 50% of MSE?
1396+ }
1397+ }
1398+
1399+ for (auto & c : eval_candidates) {
1400+ if (c.bytes == 0 ) { continue ; }
1401+ const double final_err = bias_needed ? c.error : c.mse ;
1402+ info.candidate .push_back (candidate_types{ c.type , c.bpw , c.bytes , final_err, c.mse , c.proj });
13591403 }
13601404
13611405 if (info.candidate .empty ()) {
0 commit comments