
Commit 1a3e9ea

Refactor estimate_error()
1 parent a7ee915 commit 1a3e9ea

File tree

1 file changed: +85, -106 lines

src/llama-quant.cpp

Lines changed: 85 additions & 106 deletions
@@ -737,12 +737,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     const int64_t n_per_row = t->ne[0];
     const int64_t nrows = t->ne[1];
     const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
-    const size_t sample_element_count = f32_sample.size();
-    const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
-    if (sample_row_count == 0) {
+    const size_t sample_elems = f32_sample.size();
+    const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;
+
+    if (sample_rows == 0) {
         if (out_mse) { *out_mse = 0.0; }
         if (out_proj) { *out_proj = 0.0; }
-
         return 0.0;
     }

@@ -751,105 +751,102 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         expected_rows += (size_t)rows_sample[s];
     }

-    if (expected_rows != sample_row_count) {
+    if (expected_rows != sample_rows) {
         if (out_mse) { *out_mse = infinity; }
         if (out_proj) { *out_proj = 0.0; }
-
         return infinity;
     }

     const size_t row_sz = ggml_row_size(quant_type, n_per_row);
-    const size_t buffer_sz = row_sz * sample_row_count;
+    const size_t buf_sz = row_sz * sample_rows;

-    if (quantized_buffer.size() < buffer_sz) { quantized_buffer.resize(buffer_sz); }
-    if (dequantized_buffer.size() < sample_element_count) { dequantized_buffer.resize(sample_element_count); }
+    if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); }
+    if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); }

     const bool has_values = values_sample != nullptr;
     const bool has_activations = activations_sample != nullptr;

     // Bias denominators per slice
-    std::vector<double> bias_denominator_per_slice(ne2, 0.0);
+    std::vector<double> bias_denom(ne2, 0.0);
     if (has_activations) {
         for (int64_t s = 0; s < ne2; ++s) {
-            const float * values = has_values ? values_sample + s * n_per_row : nullptr;
-            const float * activations = activations_sample + s * n_per_row;
+            const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+            const float * a = activations_sample + s * n_per_row;
             double denom = 0.0;
             for (int64_t j = 0; j < n_per_row; ++j) {
-                const double w = values ? std::max(0.0f, values[j]) : 1.0;
-                const double a = activations[j];
-                denom += w * a * a;
+                const double w = v ? std::max(0.0f, v[j]) : 1.0;
+                const double aj = a[j];
+                denom += w * aj * aj;
             }

-            bias_denominator_per_slice[s] = denom;
+            bias_denom[s] = denom;
         }
     }

-    // Weighted per-row squared norms
-    std::vector<double> row_sq_norm(sample_row_count, 0.0);
+    // Row squared norms (weighted if values present)
+    std::vector<double> row_sq_norm(sample_rows, 0.0);
     {
-        size_t offset = 0;
-        size_t row_idx = 0;
+        size_t off = 0;
+        size_t ridx = 0;
         for (int64_t s = 0; s < ne2; ++s) {
             const int64_t rs = rows_sample[s];
             if (rs == 0) { continue; }

-            const float * values = has_values ? values_sample + s * n_per_row : nullptr;
-            for (int64_t r = 0; r < rs; ++r, ++row_idx) {
-                const float * x = f32_sample.data() + offset;
-                double rsn = 0.0;
-                if (values) {
+            const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+            for (int64_t r = 0; r < rs; ++r, ++ridx) {
+                const float * x = f32_sample.data() + off;
+                double sum = 0.0;
+                if (v) {
                     for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double w = std::max(0.0f, values[j]);
+                        const double w = std::max(0.0f, v[j]);
                         const double xx = x[j];
-                        rsn += w * xx * xx;
+                        sum += w * xx * xx;
                     }
                 } else {
                     for (int64_t j = 0; j < n_per_row; ++j) {
                         const double xx = x[j];
-                        rsn += xx * xx;
+                        sum += xx * xx;
                     }
                 }
-                row_sq_norm[row_idx] = rsn;
-                offset += (size_t)n_per_row;
+
+                row_sq_norm[ridx] = sum;
+                off += (size_t)n_per_row;
             }
         }
     }

-    // Quantize sampled rows per slice -> quantized_buffer
+    // Quantize per slice into quantized_buffer
     {
-        size_t q_offset = 0;
-        size_t f_offset = 0;
-        for (int64_t slice = 0; slice < ne2; ++slice) {
-            const int64_t rs = rows_sample[slice];
+        size_t qoff = 0;
+        size_t foff = 0;
+        for (int64_t s = 0; s < ne2; ++s) {
+            const int64_t rs = rows_sample[s];
             if (rs == 0) { continue; }

-            const float * value = has_values ? values_sample + slice * n_per_row : nullptr;
-            (void)ggml_quantize_chunk(quant_type, f32_sample.data() + f_offset, quantized_buffer.data() + q_offset, 0, rs, n_per_row, value);
-            q_offset += row_sz * (size_t)rs;
-            f_offset += (size_t)rs * (size_t)n_per_row;
+            const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+            (void)ggml_quantize_chunk(quant_type, f32_sample.data() + foff, quantized_buffer.data() + qoff, 0, rs, n_per_row, v);
+            qoff += row_sz * (size_t)rs;
+            foff += (size_t)rs * (size_t)n_per_row;
         }
     }

-    // quantized_buffer -> dequantized_buffer
+    // Dequantize into dequantized_buffer
     {
         const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
-        const bool is_fp16 = quant_type == GGML_TYPE_F16;
-        const bool is_bf16 = quant_type == GGML_TYPE_BF16;
-        if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
-            traits->to_float(quantized_buffer.data(), dequantized_buffer.data(), (int)(sample_row_count * (size_t)n_per_row));
+        if (traits && traits->to_float && quant_type != GGML_TYPE_F16 && quant_type != GGML_TYPE_BF16) {
+            traits->to_float(quantized_buffer.data(), dequantized_buffer.data(), (int)(sample_rows * (size_t)n_per_row));
         } else {
-            for (size_t r = 0; r < sample_row_count; ++r) {
-                uint8_t * src = quantized_buffer.data() + r * row_sz;
+            for (size_t r = 0; r < sample_rows; ++r) {
+                const uint8_t * src = quantized_buffer.data() + r * row_sz;
                 float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
-                if (is_fp16) {
+                if (quant_type == GGML_TYPE_F16) {
                     ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
-                } else if (is_bf16) {
+                } else if (quant_type == GGML_TYPE_BF16) {
                     ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
                 } else {
                     if (!traits || !traits->to_float) {
                         if (out_mse) { *out_mse = infinity; }
                         if (out_proj) { *out_proj = 0.0; }
-
                         return infinity;
                     }
                     traits->to_float(src, dst, (int)n_per_row);
@@ -858,94 +855,77 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         }
     }

-    // Compute error
-    size_t offset = 0;
-    size_t row_idx = 0;
+    // Compute error per slice with trimmed aggregation
+    auto trimmed_sum = [&](std::vector<double> & v) -> double {
+        const int64_t n = (int64_t)v.size();
+        if (n == 0) { return 0.0; }
+        if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
+        int64_t k = (int64_t) std::floor(0.02 * (double) n); // trim 2% on each side
+        k = std::clamp<int64_t>(k, 0, n / 32); // but no more than ~3%
+        std::nth_element(v.begin(), v.begin() + k, v.end());
+        std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
+        return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
+    };
+
+    size_t off = 0;
+    size_t ridx = 0;
     double total_mse = 0.0;
     double total_proj = 0.0;
     double total_bias = 0.0;
-    for (int64_t slice = 0; slice < ne2; ++slice) {
-        const int64_t rs = rows_sample[slice];
+    for (int64_t s = 0; s < ne2; ++s) {
+        const int64_t rs = rows_sample[s];
         if (rs == 0) { continue; }

-        const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
-        const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
-        const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
+        const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+        const float * a = has_activations ? activations_sample + s * n_per_row : nullptr;
+        const double denom_bias = has_activations ? bias_denom[s] : 0.0;
         std::vector<double> row_mse_norm;
-        std::vector<double> row_proj_norm;
         row_mse_norm.reserve(rs);
-        if (activations) { row_proj_norm.reserve(rs); }
+        std::vector<double> row_proj_norm;
+        if (a) { row_proj_norm.reserve(rs); }

-        for (int64_t r = 0; r < rs; ++r, ++row_idx) {
-            const float * x = f32_sample.data() + offset;
-            const float * y = dequantized_buffer.data() + offset;
-            double weighted_mse = 0.0;
+        for (int64_t r = 0; r < rs; ++r, ++ridx) {
+            const float * x = f32_sample.data() + off;
+            const float * y = dequantized_buffer.data() + off;
+            double w_mse = 0.0;
             double bias_num = 0.0;
-            if (values && activations) {
-                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const double w = std::max(0.0f, values[j]);
-                    const double e = y[j] - x[j];
-                    const double a = activations[j];
-                    weighted_mse += w * e * e;
-                    bias_num += w * e * a;
-                }
-            } else if (values) {
-                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const double w = std::max(0.0f, values[j]);
-                    const double e = y[j] - x[j];
-                    weighted_mse += w * e * e;
-                }
-            } else {
-                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const double e = y[j] - x[j];
-                    weighted_mse += e * e;
-                }
+            for (int64_t j = 0; j < n_per_row; ++j) {
+                const double wj = v ? std::max(0.0f, v[j]) : 1.0;
+                const double e = y[j] - x[j];
+                w_mse += wj * e * e;
+                if (a) { bias_num += wj * e * a[j]; }
             }

-            const double denom_x = row_sq_norm[row_idx];
-            double m_norm = weighted_mse / (denom_x + epsilon);
+            const double denom_x = row_sq_norm[ridx];
+            const double m_norm = w_mse / (denom_x + epsilon);
             row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity);

-            if (activations) {
+            if (a) {
                 double p_norm = 0.0;
-                if (bias_denom > 0.0) {
-                    const double proj = bias_num * bias_num / (bias_denom + epsilon);
+                if (denom_bias > 0.0) {
+                    const double proj = bias_num * bias_num / (denom_bias + epsilon);
                     p_norm = std::isfinite(proj) ? proj : 0.0;
                 }
+
                 row_proj_norm.push_back(p_norm);
             }

-            offset += (size_t)n_per_row;
+            off += (size_t)n_per_row;
         }

-        // Trimmed sum to avoid outlier rows dominating the results
-        auto trimmed_sum = [&](std::vector<double> & v) -> double {
-            const int64_t n = (int64_t)v.size();
-            if (n == 0) { return 0.0; }
-            if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
-
-            int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% each side
-            k = std::clamp<int64_t>(k, 0, n / 32); // cap at ~3.125%
-            std::nth_element(v.begin(), v.begin() + k, v.end());
-            std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
-            return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
-        };
-
         const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
         const double slice_mse = trimmed_sum(row_mse_norm) * scale_rows;
-        const double slice_proj = activations ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;
+        const double slice_proj = a ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;

         total_mse += slice_mse;
         total_proj += slice_proj;

-        // per-slice lambda if provided, otherwise use scalar
-        const double bl = slice_bias_lambda ? (double)std::max(0.0f, slice_bias_lambda[slice]) : (double)tensor_bias_lambda;
+        const double bl = slice_bias_lambda ? (double)std::max(0.0f, slice_bias_lambda[s]) : (double)tensor_bias_lambda;
         total_bias += bl * slice_proj;

         if (!std::isfinite(total_mse) || !std::isfinite(total_proj) || !std::isfinite(total_bias)) {
             if (out_mse) { *out_mse = infinity; }
             if (out_proj) { *out_proj = 0.0; }
-
             return infinity;
         }
     }
@@ -954,7 +934,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     if (out_proj) { *out_proj = total_proj; }

     const double total_err = slice_bias_lambda ? total_mse + total_bias : total_mse + tensor_bias_lambda * total_proj;
-
     return std::isfinite(total_err) ? total_err : infinity;
 };

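Aside: the main structural change in this commit is that trimmed_sum, previously defined inside the per-slice loop, is hoisted above it and defined once. Below is a minimal standalone sketch of that trimmed aggregation; main() and the sample data are illustrative assumptions, not part of llama.cpp.

// Standalone sketch of the trimmed_sum helper as it reads after this commit.
// Build: g++ -std=c++17 trimmed_sum.cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Sum of v with the k smallest and k largest entries discarded, where
// k = min(2% of n, n/32). Small inputs (n < 50) are summed untrimmed.
// Note: reorders v in place via std::nth_element.
static double trimmed_sum(std::vector<double> & v) {
    const int64_t n = (int64_t)v.size();
    if (n == 0) { return 0.0; }
    if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
    int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% on each side
    k = std::clamp<int64_t>(k, 0, n / 32);             // but no more than ~3%
    std::nth_element(v.begin(), v.begin() + k, v.end());           // k smallest to the front
    std::nth_element(v.begin() + k, v.begin() + (n - k), v.end()); // k largest to the back
    return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
}

int main() {
    // 100 well-behaved per-row normalized errors plus two outlier rows
    // that would otherwise dominate a plain sum.
    std::vector<double> row_err(100, 0.01);
    row_err.push_back(1e6);
    row_err.push_back(1e6);

    std::vector<double> plain = row_err;
    printf("plain sum:   %g\n", std::accumulate(plain.begin(), plain.end(), 0.0)); // ~2e+06
    printf("trimmed sum: %g\n", trimmed_sum(row_err));                             // 0.98
    return 0;
}

Two properties worth noting: std::nth_element partitions rather than sorts, so the helper is linear on average instead of a full sort's O(n log n), and reordering v in place is harmless here because row_mse_norm and row_proj_norm are discarded immediately after aggregation.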
0 commit comments
