
Commit c709e1a

Fix MoE tensor estimation
1 parent 8503d59 · commit c709e1a

File tree: 1 file changed, +28 −17 lines changed


src/llama-quant.cpp

Lines changed: 28 additions & 17 deletions
@@ -1021,27 +1021,38 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     };
 
     // Faster to compute but may yield lower precision. Best option for the vast majority of cases
-    auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
+    auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) {
         if (!activations) { return 0.0f; }
 
-        double s  = 0.0;
-        double s2 = 0.0;
-        for (int64_t j = 0; j < n_per_row; ++j) {
-            const double w   = values ? std::max(0.0f, values[j]) : 1.0;
-            const double aw  = std::sqrt(w) * activations[j];
-            const double aw2 = aw * aw;
-            s  += aw2;
-            s2 += aw2 * aw2;
-        }
+        double accum = 0.0;
+        int ns = 0;
+
+        for (int64_t s = 0; s < std::max<int64_t>(1, ne2); ++s) {
+            const float * v = values ? values + s * n_per_row : nullptr;
+            const float * a = activations + s * n_per_row;
+
+            double s1 = 0.0;
+            double s2 = 0.0;
+            for (int64_t j = 0; j < n_per_row; ++j) {
+                const double w   = v ? std::max(0.0f, v[j]) : 1.0;
+                const double aw  = std::sqrt(w) * a[j];
+                const double aw2 = aw * aw;
+                s1 += aw2;
+                s2 += aw2 * aw2;
+            }
 
-        if (s2 <= 0.0) { return 0.0f; }
-        const auto d = (double)n_per_row;
-        double base = 1.0 - s * s / (d * s2 + epsilon);
-        base = std::clamp(base, 0.0, 1.0);
+            if (s1 > 0.0) {
+                const double n = (double)n_per_row;
+                double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n);
+                double lambda = 8.0 * (c / (c + 1.0));
+                accum += std::clamp(lambda, 0.0, 8.0);
+                ++ns;
+            }
+        }
 
-        const double lambda = std::clamp(base, 0.0, 1.0) * 8.0;
+        if (ns == 0) { return 0.0f; }
 
-        return (float)lambda;
+        return (float)(accum / ns);
     };
 
     std::vector<tensor_info> all;

@@ -1190,7 +1201,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const float * values = values_sample.empty() ? nullptr : values_sample.data();
         const float * activations = activations_sample.empty() ? nullptr : activations_sample.data();
         if (params->bpw_bias == 1) {
-            bias_lambda = fast_lambda(values, activations, n_per_row);
+            bias_lambda = fast_lambda(values, activations, n_per_row, ne2);
         } else if (params->bpw_bias == 2) {
            bias_lambda = precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates);
        }
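Below is a minimal, self-contained sketch of the estimator the updated fast_lambda implements: a concentration-based lambda is computed per expert slice (ne2 slices of n_per_row entries each) and then averaged across slices, rather than being derived from a single row population. The helper name estimate_bias_lambda, the epsilon default, and the toy main() are illustrative assumptions for this sketch and are not part of the llama.cpp source.

// Illustrative sketch (not the llama.cpp source): per-expert-slice lambda
// estimate averaged over ne2 slices, mirroring the shape of the new fast_lambda.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static float estimate_bias_lambda(const float * values, const float * activations,
                                  int64_t n_per_row, int64_t ne2, double epsilon = 1e-9) {
    if (!activations) { return 0.0f; }

    double accum = 0.0;
    int    ns    = 0;

    for (int64_t s = 0; s < std::max<int64_t>(1, ne2); ++s) {
        // each slice holds n_per_row importance values / activations
        const float * v = values ? values + s * n_per_row : nullptr;
        const float * a = activations + s * n_per_row;

        double s1 = 0.0; // sum of weighted squared activations
        double s2 = 0.0; // sum of their squares
        for (int64_t j = 0; j < n_per_row; ++j) {
            const double w   = v ? std::max(0.0f, v[j]) : 1.0;
            const double aw  = std::sqrt(w) * a[j];
            const double aw2 = aw * aw;
            s1 += aw2;
            s2 += aw2 * aw2;
        }

        if (s1 > 0.0) {
            // excess concentration of the slice's energy, mapped into [0, 8]
            const double n      = (double)n_per_row;
            const double c      = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n);
            const double lambda = 8.0 * (c / (c + 1.0));
            accum += std::clamp(lambda, 0.0, 8.0);
            ++ns;
        }
    }

    return ns == 0 ? 0.0f : (float)(accum / ns);
}

int main() {
    // two toy "expert" slices of 4 activations each, uniform importance weights
    std::vector<float> acts = { 0.1f, 0.2f, 0.3f, 4.0f,   1.0f, 1.0f, 1.0f, 1.0f };
    std::printf("lambda = %f\n", estimate_bias_lambda(nullptr, acts.data(), 4, 2));
    return 0;
}

In this toy run the second, uniform slice contributes a lambda of 0 while the first, spiky slice contributes a larger value, and the result is their average, which is the per-slice behaviour the commit adds for MoE tensors.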
