@@ -737,12 +737,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
737737 const int64_t n_per_row = t->ne [0 ];
738738 const int64_t nrows = t->ne [1 ];
739739 const int64_t ne2 = t->ne [2 ] > 0 ? t->ne [2 ] : 1 ;
740- const size_t sample_element_count = f32_sample.size ();
741- const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t )n_per_row : 0 ;
742- if (sample_row_count == 0 ) {
740+ const size_t sample_elems = f32_sample.size ();
741+ const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t )n_per_row : 0 ;
742+
743+ if (sample_rows == 0 ) {
743744 if (out_mse) { *out_mse = 0.0 ; }
744745 if (out_proj) { *out_proj = 0.0 ; }
745-
746746 return 0.0 ;
747747 }
748748
@@ -751,105 +751,102 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
751751 expected_rows += (size_t )rows_sample[s];
752752 }
753753
754- if (expected_rows != sample_row_count ) {
754+ if (expected_rows != sample_rows ) {
755755 if (out_mse) { *out_mse = infinity; }
756756 if (out_proj) { *out_proj = 0.0 ; }
757-
758757 return infinity;
759758 }
760759
761760 const size_t row_sz = ggml_row_size (quant_type, n_per_row);
762- const size_t buffer_sz = row_sz * sample_row_count ;
761+ const size_t buf_sz = row_sz * sample_rows ;
763762
764- if (quantized_buffer.size () < buffer_sz ) { quantized_buffer.resize (buffer_sz ); }
765- if (dequantized_buffer.size () < sample_element_count ) { dequantized_buffer.resize (sample_element_count ); }
763+ if (quantized_buffer.size () < buf_sz ) { quantized_buffer.resize (buf_sz ); }
764+ if (dequantized_buffer.size () < sample_elems ) { dequantized_buffer.resize (sample_elems ); }
766765
767766 const bool has_values = values_sample != nullptr ;
768767 const bool has_activations = activations_sample != nullptr ;
769768
770769 // Bias denominators per slice
771- std::vector<double > bias_denominator_per_slice (ne2, 0.0 );
770+ std::vector<double > bias_denom (ne2, 0.0 );
772771 if (has_activations) {
773772 for (int64_t s = 0 ; s < ne2; ++s) {
774- const float * values = has_values ? values_sample + s * n_per_row : nullptr ;
775- const float * activations = activations_sample + s * n_per_row;
773+ const float * v = has_values ? values_sample + s * n_per_row : nullptr ;
774+ const float * a = activations_sample + s * n_per_row;
776775 double denom = 0.0 ;
777776 for (int64_t j = 0 ; j < n_per_row; ++j) {
778- const double w = values ? std::max (0 .0f , values [j]) : 1.0 ;
779- const double a = activations [j];
780- denom += w * a * a ;
777+ const double w = v ? std::max (0 .0f , v [j]) : 1.0 ;
778+ const double aj = a [j];
779+ denom += w * aj * aj ;
781780 }
782781
783- bias_denominator_per_slice [s] = denom;
782+ bias_denom [s] = denom;
784783 }
785784 }
786785
787- // Weighted per-row squared norms
788- std::vector<double > row_sq_norm (sample_row_count , 0.0 );
786+ // Row squared norms (weighted if values present)
787+ std::vector<double > row_sq_norm (sample_rows , 0.0 );
789788 {
790- size_t offset = 0 ;
791- size_t row_idx = 0 ;
789+ size_t off = 0 ;
790+ size_t ridx = 0 ;
792791 for (int64_t s = 0 ; s < ne2; ++s) {
793792 const int64_t rs = rows_sample[s];
794793 if (rs == 0 ) { continue ; }
795794
796- const float * values = has_values ? values_sample + s * n_per_row : nullptr ;
797- for (int64_t r = 0 ; r < rs; ++r, ++row_idx ) {
798- const float * x = f32_sample.data () + offset ;
799- double rsn = 0.0 ;
800- if (values ) {
795+ const float * v = has_values ? values_sample + s * n_per_row : nullptr ;
796+ for (int64_t r = 0 ; r < rs; ++r, ++ridx ) {
797+ const float * x = f32_sample.data () + off ;
798+ double sum = 0.0 ;
799+ if (v ) {
801800 for (int64_t j = 0 ; j < n_per_row; ++j) {
802- const double w = std::max (0 .0f , values [j]);
801+ const double w = std::max (0 .0f , v [j]);
803802 const double xx = x[j];
804- rsn += w * xx * xx;
803+ sum += w * xx * xx;
805804 }
806805 } else {
807806 for (int64_t j = 0 ; j < n_per_row; ++j) {
808807 const double xx = x[j];
809- rsn += xx * xx;
808+ sum += xx * xx;
810809 }
811810 }
812- row_sq_norm[row_idx] = rsn;
813- offset += (size_t )n_per_row;
811+
812+ row_sq_norm[ridx] = sum;
813+ off += (size_t )n_per_row;
814814 }
815815 }
816816 }
817817
818- // Quantize sampled rows per slice -> quantized_buffer
818+ // Quantize per slice into quantized_buffer
819819 {
820- size_t q_offset = 0 ;
821- size_t f_offset = 0 ;
822- for (int64_t slice = 0 ; slice < ne2; ++slice ) {
823- const int64_t rs = rows_sample[slice ];
820+ size_t qoff = 0 ;
821+ size_t foff = 0 ;
822+ for (int64_t s = 0 ; s < ne2; ++s ) {
823+ const int64_t rs = rows_sample[s ];
824824 if (rs == 0 ) { continue ; }
825825
826- const float * value = has_values ? values_sample + slice * n_per_row : nullptr ;
827- (void )ggml_quantize_chunk (quant_type, f32_sample.data () + f_offset , quantized_buffer.data () + q_offset , 0 , rs, n_per_row, value );
828- q_offset += row_sz * (size_t )rs;
829- f_offset += (size_t )rs * (size_t )n_per_row;
826+ const float * v = has_values ? values_sample + s * n_per_row : nullptr ;
827+ (void )ggml_quantize_chunk (quant_type, f32_sample.data () + foff , quantized_buffer.data () + qoff , 0 , rs, n_per_row, v );
828+ qoff += row_sz * (size_t )rs;
829+ foff += (size_t )rs * (size_t )n_per_row;
830830 }
831831 }
832832
833- // quantized_buffer -> dequantized_buffer
833+ // Dequantize into dequantized_buffer
834834 {
835835 const ggml_type_traits * traits = ggml_get_type_traits (quant_type);
836- const bool is_fp16 = quant_type == GGML_TYPE_F16;
837- const bool is_bf16 = quant_type == GGML_TYPE_BF16;
838- if (!is_fp16 && !is_bf16 && traits && traits->to_float ) {
839- traits->to_float (quantized_buffer.data (), dequantized_buffer.data (), (int )(sample_row_count * (size_t )n_per_row));
836+ if (traits && traits->to_float && quant_type != GGML_TYPE_F16 && quant_type != GGML_TYPE_BF16) {
837+ traits->to_float (quantized_buffer.data (), dequantized_buffer.data (), (int )(sample_rows * (size_t )n_per_row));
840838 } else {
841- for (size_t r = 0 ; r < sample_row_count ; ++r) {
842- uint8_t * src = quantized_buffer.data () + r * row_sz;
839+ for (size_t r = 0 ; r < sample_rows ; ++r) {
840+ const uint8_t * src = quantized_buffer.data () + r * row_sz;
843841 float * dst = dequantized_buffer.data () + r * (size_t )n_per_row;
844- if (is_fp16 ) {
842+ if (quant_type == GGML_TYPE_F16 ) {
845843 ggml_fp16_to_fp32_row ((const ggml_fp16_t *)src, dst, (int )n_per_row);
846- } else if (is_bf16 ) {
844+ } else if (quant_type == GGML_TYPE_BF16 ) {
847845 ggml_bf16_to_fp32_row ((const ggml_bf16_t *)src, dst, (int )n_per_row);
848846 } else {
849847 if (!traits || !traits->to_float ) {
850848 if (out_mse) { *out_mse = infinity; }
851849 if (out_proj) { *out_proj = 0.0 ; }
852-
853850 return infinity;
854851 }
855852 traits->to_float (src, dst, (int )n_per_row);
@@ -858,94 +855,77 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
858855 }
859856 }
860857
861- // Compute error
862- size_t offset = 0 ;
863- size_t row_idx = 0 ;
858+ // Compute error per slice with trimmed aggregation
859+ auto trimmed_sum = [&](std::vector<double > & v) -> double {
860+ const int64_t n = (int64_t )v.size ();
861+ if (n == 0 ) { return 0.0 ; }
862+ if (n < 50 ) { return std::accumulate (v.begin (), v.end (), 0.0 ); }
863+ int64_t k = (int64_t ) std::floor (0.02 * (double ) n); // trim 2% on each side
864+ k = std::clamp<int64_t >(k, 0 , n / 32 ); // but no more than ~3%
865+ std::nth_element (v.begin (), v.begin () + k, v.end ());
866+ std::nth_element (v.begin () + k, v.begin () + (n - k), v.end ());
867+ return std::accumulate (v.begin () + k, v.begin () + (n - k), 0.0 );
868+ };
869+
870+ size_t off = 0 ;
871+ size_t ridx = 0 ;
864872 double total_mse = 0.0 ;
865873 double total_proj = 0.0 ;
866874 double total_bias = 0.0 ;
867- for (int64_t slice = 0 ; slice < ne2; ++slice ) {
868- const int64_t rs = rows_sample[slice ];
875+ for (int64_t s = 0 ; s < ne2; ++s ) {
876+ const int64_t rs = rows_sample[s ];
869877 if (rs == 0 ) { continue ; }
870878
871- const float * values = has_values ? values_sample + slice * n_per_row : nullptr ;
872- const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr ;
873- const double bias_denom = has_activations ? bias_denominator_per_slice[slice ] : 0.0 ;
879+ const float * v = has_values ? values_sample + s * n_per_row : nullptr ;
880+ const float * a = has_activations ? activations_sample + s * n_per_row : nullptr ;
881+ const double denom_bias = has_activations ? bias_denom[s ] : 0.0 ;
874882 std::vector<double > row_mse_norm;
875- std::vector<double > row_proj_norm;
876883 row_mse_norm.reserve (rs);
877- if (activations) { row_proj_norm.reserve (rs); }
884+ std::vector<double > row_proj_norm;
885+ if (a) { row_proj_norm.reserve (rs); }
878886
879- for (int64_t r = 0 ; r < rs; ++r, ++row_idx ) {
880- const float * x = f32_sample.data () + offset ;
881- const float * y = dequantized_buffer.data () + offset ;
882- double weighted_mse = 0.0 ;
887+ for (int64_t r = 0 ; r < rs; ++r, ++ridx ) {
888+ const float * x = f32_sample.data () + off ;
889+ const float * y = dequantized_buffer.data () + off ;
890+ double w_mse = 0.0 ;
883891 double bias_num = 0.0 ;
884- if (values && activations) {
885- for (int64_t j = 0 ; j < n_per_row; ++j) {
886- const double w = std::max (0 .0f , values[j]);
887- const double e = y[j] - x[j];
888- const double a = activations[j];
889- weighted_mse += w * e * e;
890- bias_num += w * e * a;
891- }
892- } else if (values) {
893- for (int64_t j = 0 ; j < n_per_row; ++j) {
894- const double w = std::max (0 .0f , values[j]);
895- const double e = y[j] - x[j];
896- weighted_mse += w * e * e;
897- }
898- } else {
899- for (int64_t j = 0 ; j < n_per_row; ++j) {
900- const double e = y[j] - x[j];
901- weighted_mse += e * e;
902- }
892+ for (int64_t j = 0 ; j < n_per_row; ++j) {
893+ const double wj = v ? std::max (0 .0f , v[j]) : 1.0 ;
894+ const double e = y[j] - x[j];
895+ w_mse += wj * e * e;
896+ if (a) { bias_num += wj * e * a[j]; }
903897 }
904898
905- const double denom_x = row_sq_norm[row_idx ];
906- double m_norm = weighted_mse / (denom_x + epsilon);
899+ const double denom_x = row_sq_norm[ridx ];
900+ const double m_norm = w_mse / (denom_x + epsilon);
907901 row_mse_norm.push_back (std::isfinite (m_norm) ? m_norm : infinity);
908902
909- if (activations ) {
903+ if (a ) {
910904 double p_norm = 0.0 ;
911- if (bias_denom > 0.0 ) {
912- const double proj = bias_num * bias_num / (bias_denom + epsilon);
905+ if (denom_bias > 0.0 ) {
906+ const double proj = bias_num * bias_num / (denom_bias + epsilon);
913907 p_norm = std::isfinite (proj) ? proj : 0.0 ;
914908 }
909+
915910 row_proj_norm.push_back (p_norm);
916911 }
917912
918- offset += (size_t )n_per_row;
913+ off += (size_t )n_per_row;
919914 }
920915
921- // Trimmed sum to avoid outlier rows dominating the results
922- auto trimmed_sum = [&](std::vector<double > & v) -> double {
923- const int64_t n = (int64_t )v.size ();
924- if (n == 0 ) { return 0.0 ; }
925- if (n < 50 ) { return std::accumulate (v.begin (), v.end (), 0.0 ); }
926-
927- int64_t k = (int64_t )std::floor (0.02 * (double )n); // trim 2% each side
928- k = std::clamp<int64_t >(k, 0 , n / 32 ); // cap at ~3.125%
929- std::nth_element (v.begin (), v.begin () + k, v.end ());
930- std::nth_element (v.begin () + k, v.begin () + (n - k), v.end ());
931- return std::accumulate (v.begin () + k, v.begin () + (n - k), 0.0 );
932- };
933-
934916 const double scale_rows = (double )nrows / std::max (1.0 , (double )rs);
935917 const double slice_mse = trimmed_sum (row_mse_norm) * scale_rows;
936- const double slice_proj = activations ? trimmed_sum (row_proj_norm) * scale_rows : 0.0 ;
918+ const double slice_proj = a ? trimmed_sum (row_proj_norm) * scale_rows : 0.0 ;
937919
938920 total_mse += slice_mse;
939921 total_proj += slice_proj;
940922
941- // per-slice lambda if provided, otherwise use scalar
942- const double bl = slice_bias_lambda ? (double )std::max (0 .0f , slice_bias_lambda[slice]) : (double )tensor_bias_lambda;
923+ const double bl = slice_bias_lambda ? (double )std::max (0 .0f , slice_bias_lambda[s]) : (double )tensor_bias_lambda;
943924 total_bias += bl * slice_proj;
944925
945926 if (!std::isfinite (total_mse) || !std::isfinite (total_proj) || !std::isfinite (total_bias)) {
946927 if (out_mse) { *out_mse = infinity; }
947928 if (out_proj) { *out_proj = 0.0 ; }
948-
949929 return infinity;
950930 }
951931 }
@@ -954,7 +934,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
954934 if (out_proj) { *out_proj = total_proj; }
955935
956936 const double total_err = slice_bias_lambda ? total_mse + total_bias : total_mse + tensor_bias_lambda * total_proj;
957-
958937 return std::isfinite (total_err) ? total_err : infinity;
959938 };
960939
0 commit comments