Skip to content

Commit 044fa78

Browse files
committed
Fix trimming logic
1 parent 84ada44 commit 044fa78

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

src/llama-quant.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
849849
};
850850

851851
auto delete_bpw_state = [&] {
852-
LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
853-
std::remove(checkpoint_file.c_str());
852+
std::ifstream ifs(checkpoint_file);
853+
if (ifs.good()) {
854+
LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
855+
std::remove(checkpoint_file.c_str());
856+
}
857+
854858
};
855859

856860
auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
@@ -988,14 +992,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
988992
}
989993

990994
// Compute error per slice with trimmed aggregation
991-
auto trimmed_sum = [](std::vector<double> & v) -> double {
995+
auto trimmed_mean = [](std::vector<double> & v) -> double {
992996
const int64_t n = (int64_t)v.size();
993997
if (n == 0) { return 0.0; }
994-
if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } // use all samples for small datasets
995-
996-
int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 2.5% from each tail of the distribution
998+
double sum = std::accumulate(v.begin(), v.end(), 0.0);
999+
if (n < 50) { return sum / (double)n; } // too few elements to trim
1000+
int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 5% (2.5% each side)
9971001
std::sort(v.begin(), v.end());
998-
return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
1002+
const auto num = (double)(n - 2 * k);
1003+
sum = std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
1004+
return sum / std::max(1.0, num);
9991005
};
10001006

10011007
size_t off = 0;
@@ -1028,7 +1034,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
10281034
}
10291035

10301036
const double denom_x = row_sq_norm[ridx];
1031-
const double m_norm = w_mse / (denom_x + epsilon);
1037+
const double m_norm = w_mse / (denom_x + epsilon);
10321038
row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity);
10331039

10341040
if (a) {
@@ -1044,9 +1050,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
10441050
off += (size_t)n_per_row;
10451051
}
10461052

1047-
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
1048-
const double slice_mse = trimmed_sum(row_mse_norm) * scale_rows;
1049-
const double slice_proj = a ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;
1053+
const double slice_mse = trimmed_mean(row_mse_norm) * (double)nrows;
1054+
const double slice_proj = a ? trimmed_mean(row_proj_norm) * (double)nrows : 0.0;
10501055

10511056
total_mse += slice_mse;
10521057
total_proj += slice_proj;

0 commit comments

Comments
 (0)