@@ -849,8 +849,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
849849 };
850850
851851 auto delete_bpw_state = [&] {
852- LLAMA_LOG_INFO (" %s: deleting %s\n " , func, checkpoint_file.c_str ());
853- std::remove (checkpoint_file.c_str ());
852+ std::ifstream ifs (checkpoint_file);
853+ if (ifs.good ()) {
854+ LLAMA_LOG_INFO (" %s: deleting %s\n " , func, checkpoint_file.c_str ());
855+ std::remove (checkpoint_file.c_str ());
856+ }
857+
854858 };
855859
856860 auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
@@ -988,14 +992,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
988992 }
989993
990994 // Compute error per slice with trimmed aggregation
991- auto trimmed_sum = [](std::vector<double > & v) -> double {
995+ auto trimmed_mean = [](std::vector<double > & v) -> double {
992996 const int64_t n = (int64_t )v.size ();
993997 if (n == 0 ) { return 0.0 ; }
994- if (n < 50 ) { return std::accumulate (v.begin (), v.end (), 0.0 ); } // use all samples for small datasets
995-
996- int64_t k = (int64_t ) std::floor (0.025 * (double )n); // trim 2.5% from each tail of the distribution
998+ double sum = std::accumulate (v.begin (), v.end (), 0.0 );
999+ if (n < 50 ) { return sum / ( double )n; } // too few elements to trim
1000+ int64_t k = (int64_t ) std::floor (0.025 * (double )n); // trim 5% ( 2.5% each side)
9971001 std::sort (v.begin (), v.end ());
998- return std::accumulate (v.begin () + k, v.begin () + (n - k), 0.0 );
1002+ const auto num = (double )(n - 2 * k);
1003+ sum = std::accumulate (v.begin () + k, v.begin () + (n - k), 0.0 );
1004+ return sum / std::max (1.0 , num);
9991005 };
10001006
10011007 size_t off = 0 ;
@@ -1028,7 +1034,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
10281034 }
10291035
10301036 const double denom_x = row_sq_norm[ridx];
1031- const double m_norm = w_mse / (denom_x + epsilon);
1037+ const double m_norm = w_mse / (denom_x + epsilon);
10321038 row_mse_norm.push_back (std::isfinite (m_norm) ? m_norm : infinity);
10331039
10341040 if (a) {
@@ -1044,9 +1050,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
10441050 off += (size_t )n_per_row;
10451051 }
10461052
1047- const double scale_rows = (double )nrows / std::max (1.0 , (double )rs);
1048- const double slice_mse = trimmed_sum (row_mse_norm) * scale_rows;
1049- const double slice_proj = a ? trimmed_sum (row_proj_norm) * scale_rows : 0.0 ;
1053+ const double slice_mse = trimmed_mean (row_mse_norm) * (double )nrows;
1054+ const double slice_proj = a ? trimmed_mean (row_proj_norm) * (double )nrows : 0.0 ;
10501055
10511056 total_mse += slice_mse;
10521057 total_proj += slice_proj;
0 commit comments