Skip to content

Commit 6e32244

Browse files
committed
Read statistics from imatrix
1 parent f8863b9 commit 6e32244

File tree

3 files changed

+75
-22
lines changed

3 files changed

+75
-22
lines changed

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ extern "C" {
368368
float target_bpw; // target bits per weight (bpw)
369369
bool keep_bpw_state; // keep bpw state file
370370
void * bpw_state; // pointer to bpw state file
371+
void * statistics; // pointer to statistics data
371372
} llama_model_quantize_params;
372373

373374
typedef struct llama_logit_bias {

src/llama-quant.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
631631
const std::map<int, std::string> & mapped,
632632
const std::unordered_map<std::string, std::vector<float>> * values_data,
633633
const std::unordered_map<std::string, std::vector<float>> * activations_data,
634+
const std::unordered_map<std::string, std::vector<float>> * statistics_data,
634635
const llama_model_quantize_params * params,
635636
int nthread
636637
) {
@@ -1815,6 +1816,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
18151816
}
18161817
const std::unordered_map<std::string, std::vector<float>> * values_data = nullptr;
18171818
const std::unordered_map<std::string, std::vector<float>> * activations_data = nullptr;
1819+
const std::unordered_map<std::string, std::vector<float>> * statistics_data = nullptr;
18181820
if (params->imatrix) {
18191821
values_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
18201822
if (values_data) {
@@ -1845,6 +1847,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
18451847
}
18461848
}
18471849
}
1850+
if (params->statistics) {
1851+
statistics_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->statistics);
1852+
if (statistics_data) {
1853+
LLAMA_LOG_INFO(" and %d statistics",int(statistics_data->size()));
1854+
}
1855+
}
18481856
LLAMA_LOG_INFO("\n");
18491857

18501858
gguf_context_ptr ctx_out { gguf_init_empty() };
@@ -1999,15 +2007,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
19992007
std::unordered_map<std::string, ggml_type> bpw_overrides = {};
20002008
if (params->target_bpw != -1.0f && !params->only_copy) {
20012009
if (params->imatrix) {
2002-
if (params->activations) {
2003-
LLAMA_LOG_INFO("%s: imatrix with activations provided, target bpw quantization will be more accurate\n",__func__);
2004-
} else {
2005-
LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate\n", __func__);
2006-
}
2010+
const char* base_msg = params->activations
2011+
? (params->statistics
2012+
? "imatrix with activations and statistics provided, process will be more accurate\n"
2013+
: "imatrix with activations provided, process will be accurate\n")
2014+
: "imatrix without activations provided, process will be less accurate\n";
2015+
if (params->activations) { LLAMA_LOG_INFO("%s: %s", __func__, base_msg); }
2016+
else { LLAMA_LOG_WARN("%s: %s", __func__, base_msg); }
2017+
20072018
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
2008-
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread);
2019+
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread);
20092020
} else {
2010-
LLAMA_LOG_WARN("%s: no imatrix provided, target bpw will not apply\n", __func__);
2021+
LLAMA_LOG_WARN("%s: --target-bpw requires an imatrix but none was provided, option will be ignored\n", __func__);
20112022
}
20122023
}
20132024

@@ -2269,7 +2280,8 @@ llama_model_quantize_params llama_model_quantize_default_params() {
22692280
/*.prune_layers =*/ nullptr,
22702281
/*.target_bpw =*/ -1.0f,
22712282
/*.keep_bpw_state =*/ false,
2272-
/*.bpw_state =*/ nullptr
2283+
/*.bpw_state =*/ nullptr,
2284+
/*.statistics =*/ nullptr
22732285
};
22742286

22752287
return result;

tools/quantize/quantize.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std
221221
static int load_imatrix(const std::string & imatrix_file,
222222
std::vector<std::string> & imatrix_datasets,
223223
std::unordered_map<std::string, std::vector<float>> & values_data,
224-
std::unordered_map<std::string, std::vector<float>> & activations_data) {
224+
std::unordered_map<std::string, std::vector<float>> & activations_data,
225+
std::unordered_map<std::string, std::vector<float>> & statistics_data) {
225226

226227
struct ggml_context * ctx = nullptr;
227228
struct gguf_init_params meta_gguf_params = {
@@ -256,24 +257,28 @@ static int load_imatrix(const std::string & imatrix_file,
256257
const std::string sums_suffix{ ".in_sum" };
257258
const std::string sums2_suffix{ ".in_sum2" };
258259
const std::string counts_suffix{ ".counts" };
260+
const std::string stats_suffix{ ".stats" };
259261

260262
// Using an ordered map to get a deterministic iteration order.
261-
std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
263+
std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
262264

263265
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
264266
std::string name = cur->name;
265267

266268
if (name.empty()) { continue; }
267269

268-
if (string_remove_suffix(name, sums2_suffix)) {
269-
// in_sum2
270+
if (string_remove_suffix(name, sums_suffix)) {
271+
// in_sum
270272
std::get<0>(sums_counts_for[std::move(name)]) = cur;
273+
} else if (string_remove_suffix(name, sums2_suffix)) {
274+
// in_sum2
275+
std::get<1>(sums_counts_for[std::move(name)]) = cur;
271276
} else if (string_remove_suffix(name, counts_suffix)) {
272277
// counts
273-
std::get<1>(sums_counts_for[std::move(name)]) = cur;
274-
} else if (string_remove_suffix(name, sums_suffix)) {
275-
// in_sum
276278
std::get<2>(sums_counts_for[std::move(name)]) = cur;
279+
} else if (string_remove_suffix(name, stats_suffix)) {
280+
// stats
281+
std::get<3>(sums_counts_for[std::move(name)]) = cur;
277282
}
278283
else {
279284
// ignore other tensors
@@ -282,11 +287,12 @@ static int load_imatrix(const std::string & imatrix_file,
282287

283288
for (const auto & sc : sums_counts_for) {
284289
const std::string & name = sc.first;
285-
const struct ggml_tensor * sums = std::get<2>(sc.second);
286-
const struct ggml_tensor * sums2 = std::get<0>(sc.second);
287-
const struct ggml_tensor * counts = std::get<1>(sc.second);
290+
const struct ggml_tensor * sums = std::get<0>(sc.second);
291+
const struct ggml_tensor * sums2 = std::get<1>(sc.second);
292+
const struct ggml_tensor * counts = std::get<2>(sc.second);
293+
const struct ggml_tensor * stats = std::get<3>(sc.second);
288294

289-
// check that sums, sums2 and counts have the same shape
295+
// check sums2 and counts are present, and that sums and sums2 have the same shape
290296
if (!sums2 || !counts || (sums != nullptr && ggml_nelements(sums) != ggml_nelements(sums2))) {
291297
fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
292298
gguf_free(ctx_gguf);
@@ -302,6 +308,19 @@ static int load_imatrix(const std::string & imatrix_file,
302308
if (sums) {
303309
activations.resize(ggml_nelements(sums));
304310
}
311+
if (stats) {
312+
auto & statistics = statistics_data[name];
313+
statistics.resize(ggml_nelements(stats));
314+
if (stats->type == GGML_TYPE_F32) {
315+
std::memcpy(statistics.data(), stats->data, ggml_nelements(stats) * sizeof(float));
316+
} else {
317+
fprintf(stderr, "%s: unsupported .stats type '%s' for '%s' - ignoring entry\n",
318+
__func__, ggml_type_name(stats->type), name.c_str());
319+
statistics.clear();
320+
statistics_data.erase(name);
321+
}
322+
323+
}
305324
values.resize(ggml_nelements(sums2));
306325
float max_count = 0.0f;
307326
for (int64_t j = 0; j < ne1; ++j) {
@@ -354,10 +373,11 @@ static int prepare_imatrix(const std::string & imatrix_file,
354373
const std::vector<std::string> & included_weights,
355374
const std::vector<std::string> & excluded_weights,
356375
std::unordered_map<std::string, std::vector<float>> & values_data,
357-
std::unordered_map<std::string, std::vector<float>> & activations_data) {
376+
std::unordered_map<std::string, std::vector<float>> & activations_data,
377+
std::unordered_map<std::string, std::vector<float>> & statistics_data) {
358378
int m_last_call = -1;
359379
if (!imatrix_file.empty()) {
360-
m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data);
380+
m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data, statistics_data);
361381
}
362382
if (values_data.empty()) {
363383
return m_last_call;
@@ -380,11 +400,20 @@ static int prepare_imatrix(const std::string & imatrix_file,
380400
++at;
381401
}
382402
}
403+
for (auto st = statistics_data.begin(); st != statistics_data.end();) {
404+
auto pos = st->first.find(name);
405+
if (pos != std::string::npos) {
406+
st = statistics_data.erase(st);
407+
} else {
408+
++st;
409+
}
410+
}
383411
}
384412
}
385413
if (!included_weights.empty()) {
386414
std::unordered_map<std::string, std::vector<float>> tmp_values;
387415
std::unordered_map<std::string, std::vector<float>> tmp_activations;
416+
std::unordered_map<std::string, std::vector<float>> tmp_statistics;
388417
for (const auto & name : included_weights) {
389418
for (auto & e : values_data) {
390419
auto pos = e.first.find(name);
@@ -398,9 +427,16 @@ static int prepare_imatrix(const std::string & imatrix_file,
398427
tmp_activations.emplace(std::move(a));
399428
}
400429
}
430+
for (auto & s : statistics_data) {
431+
auto pos = s.first.find(name);
432+
if (pos != std::string::npos) {
433+
tmp_statistics.emplace(std::move(s));
434+
}
435+
}
401436
}
402437
values_data = std::move(tmp_values);
403438
activations_data = std::move(tmp_activations);
439+
statistics_data = std::move(tmp_statistics);
404440
}
405441

406442
return m_last_call;
@@ -617,7 +653,8 @@ int main(int argc, char ** argv) {
617653
std::vector<std::string> imatrix_datasets;
618654
std::unordered_map<std::string, std::vector<float>> values_data;
619655
std::unordered_map<std::string, std::vector<float>> activations_data;
620-
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data);
656+
std::unordered_map<std::string, std::vector<float>> statistics_data;
657+
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data, statistics_data);
621658
if (!values_data.empty()) {
622659
params.imatrix = &values_data;
623660
{
@@ -657,6 +694,9 @@ int main(int argc, char ** argv) {
657694
if (!activations_data.empty()) {
658695
params.activations = &activations_data;
659696
}
697+
if (!statistics_data.empty()) {
698+
params.statistics = &statistics_data;
699+
}
660700
if (!kv_overrides.empty()) {
661701
kv_overrides.emplace_back();
662702
kv_overrides.back().key[0] = 0;

0 commit comments

Comments
 (0)