
Commit 04c07b3

Add better control over MSE and directional bias computation
1 parent 7d04050 commit 04c07b3

File tree

3 files changed: +39 additions, -35 deletions


include/llama.h

Lines changed: 1 addition & 1 deletion
@@ -365,7 +365,7 @@ extern "C" {
         void * tensor_types; // pointer to vector containing tensor types
         void * prune_layers; // pointer to vector containing layer indices to prune
         float target_bpw;    // target bits per weight (bpw)
-        bool precise_lambda; // use precise_lambda calculation - slow computation but very accurate
+        int32_t bpw_bias;    // type of error bias to use: 0 = no bias (MSE only), 1 = fast (default), 2 = precise (slow)
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
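
For API consumers, the renamed field slots into the existing quantize-params flow: take the defaults, set a bpw target, and pick a bias mode. A minimal sketch, assuming the usual llama_model_quantize() entry point; the file names are illustrative:

    #include "llama.h"

    int main(void) {
        // start from library defaults; bpw_bias now defaults to 1 (fast), see llama-quant.cpp below
        llama_model_quantize_params qparams = llama_model_quantize_default_params();

        qparams.target_bpw = 4.5f; // aim for roughly 4.5 bits per weight overall
        qparams.bpw_bias   = 2;    // 0 = no bias (MSE only), 1 = fast (default), 2 = precise (slow)

        // other fields (ftype, nthread, ...) keep their defaults here; returns 0 on success
        return (int) llama_model_quantize("model-f16.gguf", "model-bpw.gguf", &qparams);
    }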

src/llama-quant.cpp

Lines changed: 10 additions & 31 deletions
@@ -902,26 +902,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         return std::isfinite(total_err) ? total_err : infinity;
     };
 
-    // Scaling factor to increase lambda when activations are concentrated
-    auto directional_scale = [&](const float * values, const float * activations, int64_t n_per_row) {
-        if (!activations) { return 1.0f; }
-        double sum_v = 0.0;
-        double sum_aw2 = 0.0;
-        double sum_a2 = 0.0;
-        for (int64_t j = 0; j < n_per_row; ++j) {
-            const double v = values ? std::max(0.0f, values[j]) : 1.0;
-            const double a = activations[j];
-            sum_v += v;
-            sum_aw2 += v * a * a;
-            sum_a2 += a * a;
-        }
-        const double rms_a = std::sqrt(sum_a2 / std::max(1.0, (double)n_per_row));
-        const double denom = std::sqrt(std::max(epsilon, sum_v)) * std::max(epsilon, rms_a);
-        const double scale = denom > 0.0 ? std::sqrt(sum_aw2) / denom : 1.0;
-
-        return (float)std::clamp(scale, 0.5, 2.0);
-    };
-
     // Higher precision but much longer to compute
     auto precise_lambda = [&](const ggml_tensor * t,
                               const std::vector<float> & f32_sample,
@@ -979,11 +959,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         if (ratios.empty()) { return 0.0f; }
 
         std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
-        double lambda = ratios[ratios.size() / 2];
-
-        const float scale = directional_scale(values, activations, n_per_row);
-        lambda *= scale;
-        lambda = std::clamp(lambda, 0.0, 8.0);
+        const double lambda = std::clamp(ratios[ratios.size() / 2], 0.0, 8.0);
 
         return (float)lambda;
     };
@@ -1007,8 +983,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         double base = 1.0 - s * s / (d * s2 + epsilon);
         base = std::clamp(base, 0.0, 1.0);
 
-        const double scale = directional_scale(values, activations, n_per_row);
-        const double lambda = std::clamp(base * scale, 0.0, 1.0) * 8.0;
+        const double lambda = std::clamp(base, 0.0, 1.0) * 8.0;
 
         return (float)lambda;
     };
@@ -1159,8 +1134,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         {
            const float * values = values_sample.empty() ? nullptr : values_sample.data();
            const float * activations = activations_sample.empty() ? nullptr : activations_sample.data();
-           bias_lambda = params->precise_lambda ? precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates) :
-                                                  fast_lambda(values, activations, n_per_row);
+           if (params->bpw_bias == 1) {
+               bias_lambda = fast_lambda(values, activations, n_per_row);
+           } else if (params->bpw_bias == 2) {
+               bias_lambda = precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates);
+           }
         }
 
         // Now evaluate candidates
@@ -1656,7 +1634,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
            } else {
                LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate - ", __func__);
            }
-           LLAMA_LOG_INFO("using %s\n", params->precise_lambda ? "precise lambda (slow)" : "fast lambda");
+           const char* msg[] = {"no bias (MSE only)", "fast (default)", "precise (slow)"};
+           LLAMA_LOG_INFO("using %s error estimation\n", msg[params->bpw_bias]);
            LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
            bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params, nthread);
        } else {
@@ -1967,7 +1946,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
        /*.tensor_type =*/ nullptr,
        /*.prune_layers =*/ nullptr,
        /*.target_bpw =*/ -1.0f,
-       /*.precise_lambda =*/ false
+       /*.bpw_bias =*/ 1
    };
 
    return result;
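
With directional_scale removed, the precise path boils down to taking the median of the sampled error ratios and clamping it to [0, 8], and the fast path to clamping its base estimate; when bpw_bias is 0 neither helper runs, so candidates are ranked by MSE alone, matching the new log message. A standalone sketch of just the median-and-clamp step, where the ratios vector stands in for the per-candidate ratios computed inside target_bpw_type() and the helper name is illustrative:

    #include <algorithm>
    #include <vector>

    // Median of the sampled ratios, clamped to the same [0, 8] range used above.
    static float median_bias_lambda(std::vector<float> ratios) {
        if (ratios.empty()) { return 0.0f; }

        // nth_element puts the (upper) median element in place without a full sort
        std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
        const double lambda = std::clamp((double) ratios[ratios.size() / 2], 0.0, 8.0);

        return (float) lambda;
    }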

tools/quantize/quantize.cpp

Lines changed: 28 additions & 3 deletions
@@ -134,7 +134,7 @@ static void usage(const char * executable) {
     printf("      Advanced option to remove all tensors from the given layers\n");
     printf("  --target-bpw: target bits per weight (bpw). Must be a positive number between 0.0 and 16.0\n");
     printf("      Advanced option to automatically select quantization types to achieve a total bits per weight (bpw) target\n");
-    printf("  --precise-lambda: given a target bpw, use a high-precision error computation at the expense of longer processing times\n");
+    printf("  --bpw-bias: type of error bias to use: 0 = no bias (MSE only), 1 = fast (default), 2 = precise (slow)\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -496,6 +496,27 @@ static bool parse_target_bpw(const char * data, float & target_bpw) {
     return true;
 }
 
+static bool parse_bpw_bias(const char * data, int & bpw_bias) {
+    if (!data) {
+        printf("\n%s: error bias type not provided\n\n", __func__);
+        return false;
+    }
+
+    try {
+        bpw_bias = std::stoi(data);
+        if (bpw_bias < 0 || bpw_bias > 2) {
+            printf("\n%s: error bias type must be one of 0 (no bias, MSE only), 1 (fast), or 2 (precise, but slow)\n\n", __func__);
+            return false;
+        }
+    }
+    catch (const std::exception & e) {
+        printf("\n%s: '%s' is not valid. Error bias type must be one of 0 (no bias, MSE only), 1 (fast), or 2 (precise, but slow)\n\n", __func__, data);
+        return false;
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -510,6 +531,7 @@ int main(int argc, char ** argv) {
     std::vector<tensor_quantization> tensor_types;
     std::vector<int> prune_layers;
     float target_bpw = -1.0f;
+    int bpw_bias = 1;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -540,8 +562,11 @@
             if (arg_idx == argc-1 || !parse_target_bpw(argv[++arg_idx], target_bpw)) {
                 usage(argv[0]);
             }
-        } else if (strcmp(argv[arg_idx], "--precise-lambda") == 0) {
-            params.precise_lambda = true;
+        } else if (strcmp(argv[arg_idx], "--bpw-bias") == 0) {
+            if (arg_idx == argc-1 || !parse_bpw_bias(argv[++arg_idx], bpw_bias)) {
+                usage(argv[0]);
+            }
+            params.bpw_bias = bpw_bias;
         } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
             if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
                 usage(argv[0]);
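
Defaulting bpw_bias to 1 keeps the previous fast-lambda behaviour for anyone who does not pass the flag, while 0 and 2 opt into pure-MSE ranking or the slower precise estimate. A quick way to sanity-check the new parser is to feed it in-range, out-of-range, and non-numeric input; a minimal sketch, assuming it is compiled into the same translation unit as parse_bpw_bias() (the function is static) and with the test wrapper name purely illustrative:

    #include <cassert>

    static void test_parse_bpw_bias() {
        int bias = 1;
        assert(parse_bpw_bias("2", bias) && bias == 2); // in range: precise mode selected
        assert(!parse_bpw_bias("7", bias));             // out of range: rejected
        assert(!parse_bpw_bias("fast", bias));          // non-numeric: std::stoi throws, rejected
    }

On the command line the mode is then chosen with the new flag, e.g. something along the lines of llama-quantize --target-bpw 4.5 --bpw-bias 2 model-f16.gguf model-out.gguf Q4_K_M (exact positional arguments follow the tool's usual usage).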
