99#include < mutex>
1010#include < vector>
1111#include < unordered_map>
12+ #include < map>
1213#include < algorithm>
1314
1415#if defined(_MSC_VER)
@@ -24,6 +25,14 @@ static void print_usage(int, char ** argv) {
2425 LOG_TEE (" \n " );
2526}
2627
// Strips `suffix` from the end of `str`, modifying `str` in place.
//
// Returns true when `str` ends with `suffix`, in which case the trailing
// suffix.size() characters are removed from `str`. Returns false and leaves
// `str` untouched otherwise. An empty `suffix` always matches and removes
// nothing.
static bool str_remove_suffix(std::string & str, const std::string & suffix) {
    // A string shorter than the suffix cannot end with it; checking this first
    // also keeps the subtraction below from wrapping around.
    if (str.size() < suffix.size()) {
        return false;
    }

    const size_t stem_len = str.size() - suffix.size();

    // Compare exactly the trailing suffix.size() characters against the suffix.
    if (str.compare(stem_len, suffix.size(), suffix) != 0) {
        return false;
    }

    // Truncate in place instead of building a temporary with substr().
    str.resize(stem_len);
    return true;
}
35+
2736static const char * const LLM_KV_IMATRIX_DATASET = " imatrix.dataset" ;
2837static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = " imatrix.chunk_count" ;
2938static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = " imatrix.chunk_size" ;
@@ -302,8 +311,8 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const {
302311 if (nval > 0 ) {
303312 struct ggml_tensor * sums = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, nval / nmat, nmat);
304313 struct ggml_tensor * counts = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, 1 , nmat);
305- ggml_set_name (sums, (name + " .sums" ) .c_str ());
306- ggml_set_name (counts, (name + " .counts" ) .c_str ());
314+ ggml_format_name (sums, " %s .sums" , name .c_str ());
315+ ggml_format_name (counts, " %s .counts" , name .c_str ());
307316
308317 for (int32_t j = 0 ; j < nval; ++j) {
309318 ((float *) sums->data )[j] = (float ) stat.values [j];
@@ -338,7 +347,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
338347 return false ;
339348 }
340349 const int32_t n_entries = gguf_get_n_tensors (ctx_gguf);
341- if (n_entries < 2 ) {
350+ if (n_entries < 1 ) {
342351 fprintf (stderr, " %s: no data in file %s\n " , __func__, file_name);
343352 gguf_free (ctx_gguf);
344353 ggml_free (ctx);
@@ -348,51 +357,73 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
348357 const std::string sums_suffix{" .sums" };
349358 const std::string counts_suffix{" .counts" };
350359
351- // TODO: allow loading from mis-ordered imatrix files
352- for (int32_t i = 0 ; i < n_entries - 1 ; i += 2 ) {
353- std::string sums_name{gguf_get_tensor_name (ctx_gguf, i + 0 )};
354- std::string counts_name{gguf_get_tensor_name (ctx_gguf, i + 1 )};
355-
356- if (sums_name.size () < sums_suffix.size () ||
357- counts_name.size () < counts_suffix.size () ||
358- !std::equal (sums_name.begin (), sums_name.end () - sums_suffix.size (), counts_name.begin ()) ||
359- !std::equal (sums_suffix.rbegin (), sums_suffix.rend (), sums_name.rbegin ()) ||
360- !std::equal (counts_suffix.rbegin (), counts_suffix.rend (), counts_name.rbegin ())) {
361- fprintf (stderr, " %s: mismatched sums and counts for entry %d\n " , __func__, i / 2 );
360+ // Could re-use m_stats instead, but this allows
361+ // checking for completeness of *each* loaded imatrix file
362+ // and also makes it easier to re-use a similar implementation in quantize.cpp
363+ // Using an ordered map to get a deterministic iteration order.
364+ std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
365+
366+ for (struct ggml_tensor * cur = ggml_get_first_tensor (ctx); cur; cur = ggml_get_next_tensor (ctx, cur)) {
367+ std::string name = cur->name ;
368+
369+ if (name.empty ()) { continue ; }
370+
371+ if (str_remove_suffix (name, sums_suffix)) {
372+ // sums
373+ sums_counts_for[name].first = cur;
374+ } else if (str_remove_suffix (name, counts_suffix)) {
375+ // counts
376+ sums_counts_for[name].second = cur;
377+ } else {
378+ fprintf (stderr, " %s: invalid imatrix tensor name: %s\n " , __func__, name.c_str ());
362379 gguf_free (ctx_gguf);
363380 ggml_free (ctx);
364381 return false ;
365382 }
383+ }
384+
385+ for (const auto & sc : sums_counts_for) {
386+ const std::string & name = sc.first ;
387+ const struct ggml_tensor * sums = sc.second .first ;
388+ const struct ggml_tensor * counts = sc.second .second ;
366389
367- struct ggml_tensor * sums = ggml_get_tensor (ctx, sums_name.c_str ());
368- struct ggml_tensor * counts = ggml_get_tensor (ctx, counts_name.c_str ());
369390 if (!sums || !counts) {
370- fprintf (stderr, " %s: failed reading data for entry %d \n " , __func__, i / 2 );
391+ fprintf (stderr, " %s: mismatched sums and counts for %s \n " , __func__, name. c_str () );
371392 gguf_free (ctx_gguf);
372393 ggml_free (ctx);
373394 return false ;
374395 }
375396
376- std::string name = sums_name.substr (0 , sums_name.size () - sums_suffix.size ());
377397 auto & e = m_stats[name];
378398
379- int32_t nval = ggml_nelements (sums);
399+ int64_t nval = ggml_nelements (sums);
380400 if (e.values .empty ()) {
381401 e.values .resize (nval, 0 );
402+ } else if ((size_t ) nval != e.values .size ()) {
403+ fprintf (stderr, " %s: mismatched sums size for %s: %zu != %zu\n " , __func__, name.c_str (), (size_t ) nval, e.values .size ());
404+ gguf_free (ctx_gguf);
405+ ggml_free (ctx);
406+ return false ;
382407 }
383- int32_t ncounts = ggml_nelements (counts);
408+
409+ int64_t ncounts = ggml_nelements (counts);
384410 if (e.counts .empty ()) {
385411 e.counts .resize (ncounts, 0 );
386412 } else if (e.counts .size () == 1 && ncounts > 1 ) {
387413 // broadcast, when loading an old imatrix
388414 e.counts .resize (ncounts, e.counts [0 ]);
415+ } else if ((size_t ) ncounts != e.counts .size ()) {
416+ fprintf (stderr, " %s: mismatched counts size for %s: %zu != %zu\n " , __func__, name.c_str (), (size_t ) ncounts, e.counts .size ());
417+ gguf_free (ctx_gguf);
418+ ggml_free (ctx);
419+ return false ;
389420 }
390421
391422 // Recreate the state as expected by save_imatrix()
392- for (int32_t j = 0 ; j < nval; j++) {
423+ for (int64_t j = 0 ; j < nval; j++) {
393424 e.values [j] += ((const float *) sums->data )[j];
394425 }
395- for (int32_t j = 0 ; j < ncounts; j++) {
426+ for (int64_t j = 0 ; j < ncounts; j++) {
396427 e.counts [j] += std::lround (((const float *) counts->data )[j]);
397428 }
398429 }
0 commit comments