@@ -1021,27 +1021,38 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
10211021 };
10221022
10231023 // Faster to compute but may yield lower precision. Best option for the vast majority of cases
1024- auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
1024+ auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2 ) {
10251025 if (!activations) { return 0 .0f ; }
10261026
1027- double s = 0.0 ;
1028- double s2 = 0.0 ;
1029- for (int64_t j = 0 ; j < n_per_row; ++j) {
1030- const double w = values ? std::max (0 .0f , values[j]) : 1.0 ;
1031- const double aw = std::sqrt (w) * activations[j];
1032- const double aw2 = aw * aw;
1033- s += aw2;
1034- s2 += aw2 * aw2;
1035- }
1027+ double accum = 0.0 ;
1028+ int ns = 0 ;
1029+
1030+ for (int64_t s = 0 ; s < std::max<int64_t >(1 , ne2); ++s) {
1031+ const float * v = values ? values + s * n_per_row : nullptr ;
1032+ const float * a = activations + s * n_per_row;
1033+
1034+ double s1 = 0.0 ;
1035+ double s2 = 0.0 ;
1036+ for (int64_t j = 0 ; j < n_per_row; ++j) {
1037+ const double w = v ? std::max (0 .0f , v[j]) : 1.0 ;
1038+ const double aw = std::sqrt (w) * a[j];
1039+ const double aw2 = aw * aw;
1040+ s1 += aw2;
1041+ s2 += aw2 * aw2;
1042+ }
10361043
1037- if (s2 <= 0.0 ) { return 0 .0f ; }
1038- const auto d = (double )n_per_row;
1039- double base = 1.0 - s * s / (d * s2 + epsilon);
1040- base = std::clamp (base, 0.0 , 1.0 );
1044+ if (s1 > 0.0 ) {
1045+ const double n = (double )n_per_row;
1046+ double c = std::max (0.0 , s2 / (s1 * s1 + epsilon) - 1.0 / n);
1047+ double lambda = 8.0 * (c / (c + 1.0 ));
1048+ accum += std::clamp (lambda, 0.0 , 8.0 );
1049+ ++ns;
1050+ }
1051+ }
10411052
1042- const double lambda = std::clamp (base, 0.0 , 1.0 ) * 8.0 ;
1053+ if (ns == 0 ) { return 0 . 0f ; }
10431054
1044- return (float )lambda ;
1055+ return (float )(accum / ns) ;
10451056 };
10461057
10471058 std::vector<tensor_info> all;
@@ -1190,7 +1201,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
11901201 const float * values = values_sample.empty () ? nullptr : values_sample.data ();
11911202 const float * activations = activations_sample.empty () ? nullptr : activations_sample.data ();
11921203 if (params->bpw_bias == 1 ) {
1193- bias_lambda = fast_lambda (values, activations, n_per_row);
1204+ bias_lambda = fast_lambda (values, activations, n_per_row, ne2 );
11941205 } else if (params->bpw_bias == 2 ) {
11951206 bias_lambda = precise_lambda (t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates);
11961207 }
0 commit comments