Skip to content

Commit 5b0d3f6

Browse files
committed
Automatically determine if bias error is significant
1 parent c93131c commit 5b0d3f6

File tree

1 file changed

+48
-4
lines changed

1 file changed

+48
-4
lines changed

src/llama-quant.cpp

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
637637
float bpw;
638638
size_t bytes;
639639
double error;
640+
double mse = 0.0;
641+
double proj = 0.0;
640642
};
641643

642644
struct tensor_info {
@@ -1340,9 +1342,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
13401342
const ggml_type tensor_types = compatible_candidates[i];
13411343
const auto bpw = (float)tensor_bpw(tensor, tensor_types);
13421344
const size_t bytes = tensor_bytes(tensor, tensor_types);
1345+
double mse = 0.0;
1346+
double proj = 0.0;
13431347
const auto err = estimate_error(tensor, tensor_types, f32_sample, rows_sample, values, activations,
1344-
tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda);
1345-
eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err };
1348+
tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj);
1349+
eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err, mse, proj };
13461350
}
13471351
});
13481352
}
@@ -1354,8 +1358,48 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
13541358
check_signal_handler(all);
13551359
}
13561360

1357-
for (auto &c : eval_candidates) {
1358-
if (c.bytes > 0) { info.candidate.push_back(c); }
1361+
// Check if biasing is needed
1362+
bool bias_needed = false;
1363+
if (!lambdas.empty()) {
1364+
int min_mse = -1;
1365+
int min_bias = -1;
1366+
{
1367+
double best_mse = std::numeric_limits<double>::infinity();
1368+
double best_err = std::numeric_limits<double>::infinity();
1369+
for (int i = 0; i < (int)eval_candidates.size(); ++i) {
1370+
const auto & c = eval_candidates[i];
1371+
if (c.bytes == 0) { continue; }
1372+
if (c.mse < best_mse) {
1373+
best_mse = c.mse;
1374+
min_mse = i;
1375+
}
1376+
if (c.error < best_err) {
1377+
best_err = c.error;
1378+
min_bias = i;
1379+
}
1380+
}
1381+
}
1382+
1383+
if (min_mse != min_bias) {
1384+
bias_needed = true;
1385+
} else {
1386+
double max_rel_bias = 0.0;
1387+
for (const auto & c : eval_candidates) {
1388+
if (c.bytes == 0) { continue; }
1389+
const double mse = std::max(c.mse, epsilon);
1390+
const double bias_term = std::max(0.0, c.error - c.mse);
1391+
const double rel = bias_term / mse;
1392+
max_rel_bias = std::max(rel, max_rel_bias);
1393+
}
1394+
1395+
bias_needed = max_rel_bias >= 0.5; // >= 50% of MSE?
1396+
}
1397+
}
1398+
1399+
for (auto & c : eval_candidates) {
1400+
if (c.bytes == 0) { continue; }
1401+
const double final_err = bias_needed ? c.error : c.mse;
1402+
info.candidate.push_back(candidate_types{ c.type, c.bpw, c.bytes, final_err, c.mse, c.proj });
13591403
}
13601404

13611405
if (info.candidate.empty()) {

0 commit comments

Comments
 (0)