Skip to content

Commit 630750f

Browse files
committed
Validate number of elements if in_sum is present
1 parent 1f72bc1 commit 630750f

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

tools/imatrix/imatrix.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -940,11 +940,11 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
940940

941941
for (const auto & sc : sums_counts_for) {
942942
const std::string & name = sc.first;
943-
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
943+
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
944944
const struct ggml_tensor * in_sum2 = std::get<0>(sc.second);
945945
const struct ggml_tensor * counts = std::get<1>(sc.second);
946946

947-
if (!in_sum2 || !counts) {
947+
if (!in_sum2 || !counts || (in_sum != nullptr && ggml_nelements(in_sum) != ggml_nelements(in_sum2))) {
948948
LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str());
949949
gguf_free(ctx_gguf);
950950
ggml_free(ctx);
@@ -981,16 +981,12 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
981981

982982
// Recreate the state as expected by save_imatrix()
983983
for (int64_t j = 0; j < nval; j++) {
984+
if (in_sum != nullptr) { e.activations[j] += ((const float *) in_sum->data)[j]; }
984985
e.values[j] += ((const float *) in_sum2->data)[j];
985986
}
986987
for (int64_t j = 0; j < ncounts; j++) {
987988
e.counts[j] += std::lround(((const float *) counts->data)[j]);
988989
}
989-
if (in_sum != nullptr) {
990-
for (int64_t j = 0; j < nval; j++) {
991-
e.activations[j] += ((const float *) in_sum->data)[j];
992-
}
993-
}
994990
}
995991

996992
// TODO: extract into its own method; this is also used by the legacy format

0 commit comments

Comments
 (0)