@@ -221,7 +221,8 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std
221221static int load_imatrix (const std::string & imatrix_file,
222222 std::vector<std::string> & imatrix_datasets,
223223 std::unordered_map<std::string, std::vector<float >> & values_data,
224- std::unordered_map<std::string, std::vector<float >> & activations_data) {
224+ std::unordered_map<std::string, std::vector<float >> & activations_data,
225+ std::unordered_map<std::string, std::vector<float >> & statistics_data) {
225226
226227 struct ggml_context * ctx = nullptr ;
227228 struct gguf_init_params meta_gguf_params = {
@@ -256,24 +257,28 @@ static int load_imatrix(const std::string & imatrix_file,
256257 const std::string sums_suffix{ " .in_sum" };
257258 const std::string sums2_suffix{ " .in_sum2" };
258259 const std::string counts_suffix{ " .counts" };
260+ const std::string stats_suffix{ " .stats" };
259261
260262 // Using an ordered map to get a deterministic iteration order.
261- std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
263+ std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor * >> sums_counts_for;
262264
263265 for (struct ggml_tensor * cur = ggml_get_first_tensor (ctx); cur; cur = ggml_get_next_tensor (ctx, cur)) {
264266 std::string name = cur->name ;
265267
266268 if (name.empty ()) { continue ; }
267269
268- if (string_remove_suffix (name, sums2_suffix )) {
269- // in_sum2
270+ if (string_remove_suffix (name, sums_suffix )) {
271+ // in_sum
270272 std::get<0 >(sums_counts_for[std::move (name)]) = cur;
273+ } else if (string_remove_suffix (name, sums2_suffix)) {
274+ // in_sum2
275+ std::get<1 >(sums_counts_for[std::move (name)]) = cur;
271276 } else if (string_remove_suffix (name, counts_suffix)) {
272277 // counts
273- std::get<1 >(sums_counts_for[std::move (name)]) = cur;
274- } else if (string_remove_suffix (name, sums_suffix)) {
275- // in_sum
276278 std::get<2 >(sums_counts_for[std::move (name)]) = cur;
279+ } else if (string_remove_suffix (name, stats_suffix)) {
280+ // stats
281+ std::get<3 >(sums_counts_for[std::move (name)]) = cur;
277282 }
278283 else {
279284 // ignore other tensors
@@ -282,11 +287,12 @@ static int load_imatrix(const std::string & imatrix_file,
282287
283288 for (const auto & sc : sums_counts_for) {
284289 const std::string & name = sc.first ;
285- const struct ggml_tensor * sums = std::get<2 >(sc.second );
286- const struct ggml_tensor * sums2 = std::get<0 >(sc.second );
287- const struct ggml_tensor * counts = std::get<1 >(sc.second );
290+ const struct ggml_tensor * sums = std::get<0 >(sc.second );
291+ const struct ggml_tensor * sums2 = std::get<1 >(sc.second );
292+ const struct ggml_tensor * counts = std::get<2 >(sc.second );
293+ const struct ggml_tensor * stats = std::get<3 >(sc.second );
288294
289- // check that sums, sums2 and counts have the same shape
295+ // check sums2 and counts are present, and that sums and sums2 have the same shape
290296 if (!sums2 || !counts || (sums != nullptr && ggml_nelements (sums) != ggml_nelements (sums2))) {
291297 fprintf (stderr, " %s: mismatched sums and counts for %s\n " , __func__, name.c_str ());
292298 gguf_free (ctx_gguf);
@@ -302,6 +308,19 @@ static int load_imatrix(const std::string & imatrix_file,
302308 if (sums) {
303309 activations.resize (ggml_nelements (sums));
304310 }
311+ if (stats) {
312+ auto & statistics = statistics_data[name];
313+ statistics.resize (ggml_nelements (stats));
314+ if (stats->type == GGML_TYPE_F32) {
315+ std::memcpy (statistics.data (), stats->data , ggml_nelements (stats) * sizeof (float ));
316+ } else {
317+ fprintf (stderr, " %s: unsupported .stats type '%s' for '%s' - ignoring entry\n " ,
318+ __func__, ggml_type_name (stats->type ), name.c_str ());
319+ statistics.clear ();
320+ statistics_data.erase (name);
321+ }
322+
323+ }
305324 values.resize (ggml_nelements (sums2));
306325 float max_count = 0 .0f ;
307326 for (int64_t j = 0 ; j < ne1; ++j) {
@@ -354,10 +373,11 @@ static int prepare_imatrix(const std::string & imatrix_file,
354373 const std::vector<std::string> & included_weights,
355374 const std::vector<std::string> & excluded_weights,
356375 std::unordered_map<std::string, std::vector<float >> & values_data,
357- std::unordered_map<std::string, std::vector<float >> & activations_data) {
376+ std::unordered_map<std::string, std::vector<float >> & activations_data,
377+ std::unordered_map<std::string, std::vector<float >> & statistics_data) {
358378 int m_last_call = -1 ;
359379 if (!imatrix_file.empty ()) {
360- m_last_call = load_imatrix (imatrix_file, imatrix_dataset, values_data, activations_data);
380+ m_last_call = load_imatrix (imatrix_file, imatrix_dataset, values_data, activations_data, statistics_data );
361381 }
362382 if (values_data.empty ()) {
363383 return m_last_call;
@@ -380,11 +400,20 @@ static int prepare_imatrix(const std::string & imatrix_file,
380400 ++at;
381401 }
382402 }
403+ for (auto st = statistics_data.begin (); st != statistics_data.end ();) { // drop statistics entries whose key contains the excluded name
404+ auto pos = st->first .find (name);
405+ if (pos != std::string::npos) {
406+ st = statistics_data.erase (st); // must erase from the map `st` iterates; erase returns the next valid iterator
407+ } else {
408+ ++st;
409+ }
410+ }
383411 }
384412 }
385413 if (!included_weights.empty ()) {
386414 std::unordered_map<std::string, std::vector<float >> tmp_values;
387415 std::unordered_map<std::string, std::vector<float >> tmp_activations;
416+ std::unordered_map<std::string, std::vector<float >> tmp_statistics;
388417 for (const auto & name : included_weights) {
389418 for (auto & e : values_data) {
390419 auto pos = e.first .find (name);
@@ -398,9 +427,16 @@ static int prepare_imatrix(const std::string & imatrix_file,
398427 tmp_activations.emplace (std::move (a));
399428 }
400429 }
430+ for (auto & s : statistics_data) {
431+ auto pos = s.first .find (name);
432+ if (pos != std::string::npos) {
433+ tmp_statistics.emplace (std::move (s));
434+ }
435+ }
401436 }
402437 values_data = std::move (tmp_values);
403438 activations_data = std::move (tmp_activations);
439+ statistics_data = std::move (tmp_statistics);
404440 }
405441
406442 return m_last_call;
@@ -617,7 +653,8 @@ int main(int argc, char ** argv) {
617653 std::vector<std::string> imatrix_datasets;
618654 std::unordered_map<std::string, std::vector<float >> values_data;
619655 std::unordered_map<std::string, std::vector<float >> activations_data;
620- int m_last_call = prepare_imatrix (imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data);
656+ std::unordered_map<std::string, std::vector<float >> statistics_data;
657+ int m_last_call = prepare_imatrix (imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data, statistics_data);
621658 if (!values_data.empty ()) {
622659 params.imatrix = &values_data;
623660 {
@@ -657,6 +694,9 @@ int main(int argc, char ** argv) {
657694 if (!activations_data.empty ()) {
658695 params.activations = &activations_data;
659696 }
697+ if (!statistics_data.empty ()) {
698+ params.statistics = &statistics_data;
699+ }
660700 if (!kv_overrides.empty ()) {
661701 kv_overrides.emplace_back ();
662702 kv_overrides.back ().key [0 ] = 0 ;
0 commit comments