Commit c128b28

Add --prune-layers command line option

1 parent: 63aa3f3

File tree: 3 files changed (+55, -15 lines)

examples/quantize/quantize.cpp

Lines changed: 41 additions & 5 deletions
@@ -101,13 +101,11 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
     return false;
 }
 
-// usage:
-//  ./llama-quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
-//
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable);
-    printf("       [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
+    printf("       [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
+    printf("       model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
@@ -118,6 +116,8 @@ static void usage(const char * executable) {
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
     printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
+    printf("  --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
+    printf("      Advanced option to remove all tensors from the given layers\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
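With the usage text above in place, a hypothetical invocation might look as follows; the file names, layer ids, quantization type and thread count are illustrative only, while the flag itself comes from this commit:

    ./llama-quantize --prune-layers 20,21,22 model-f32.gguf model-q4_k_m.gguf q4_k_m 8

The comma-separated layer list is handled by the new parse_layer_prune helper added in the next hunk.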
@@ -349,6 +349,34 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
     return true;
 }
 
+static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
+    if (!data) {
+        printf("\n%s: no layer prunning ids provided\n\n", __func__);
+        return false;
+    }
+
+    const auto block_ids = string_split<std::string>(data, ',');
+
+    for ( const auto & block_id : block_ids) {
+
+        try {
+            std::stoi(block_id);
+        } catch (...) {
+            printf("%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
+            return false;
+        }
+
+        int id = std::stoi(block_id);
+        if (id < 0) {
+            printf("\n%s: invalid layer id '%s'\n\n", __func__, block_id.c_str());
+            return false;
+        }
+        prune_layers.emplace_back(id);
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
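A minimal sketch of how the helper is driven (the id string is an arbitrary example):

    std::vector<int> prune_layers;
    if (parse_layer_prune("3,7,11", prune_layers)) {
        // prune_layers now holds {3, 7, 11}; a negative or non-numeric token
        // would instead print an error and make the call return false.
    }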
@@ -361,6 +389,7 @@ int main(int argc, char ** argv) {
     std::vector<std::string> included_weights, excluded_weights;
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<tensor_quantization> tensor_types;
+    std::vector<int> prune_layers;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -387,6 +416,10 @@ int main(int argc, char ** argv) {
             if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
+            if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
@@ -474,6 +507,9 @@ int main(int argc, char ** argv) {
     if (!tensor_types.empty()) {
         params.tensor_types = &tensor_types;
     }
+    if (!prune_layers.empty()) {
+        params.prune_layers = &prune_layers;
+    }
 
     llama_backend_init();
 
include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -379,6 +379,7 @@ extern "C" {
         void * imatrix;      // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
         void * tensor_types; // pointer to vector containing tensor types
+        void * prune_layers; // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
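The new field follows the same convention as kv_overrides and tensor_types: an opaque pointer to a caller-owned std::vector<int>. A minimal sketch of driving it through the public API, assuming hypothetical file names, layer ids and ftype choice (only the prune_layers field itself is from this commit):

    #include "llama.h"
    #include <vector>

    int main() {
        std::vector<int> prune_layers = {3, 7};   // blocks to drop from the output model

        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.ftype        = LLAMA_FTYPE_MOSTLY_Q4_K_M;
        params.prune_layers = &prune_layers;      // consumed in llama-quant.cpp below

        llama_backend_init();
        const uint32_t ret = llama_model_quantize("model-f32.gguf", "model-q4_k_m.gguf", &params);
        llama_backend_free();
        return ret == 0 ? 0 : 1;
    }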

src/llama-quant.cpp

Lines changed: 13 additions & 10 deletions
@@ -13,9 +13,6 @@
 #include <thread>
 #include <unordered_map>
 
-//static std::vector prune_map = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
-static std::vector<int> prune_map = {3};
-
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -64,7 +61,7 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map<
 
     for (const auto & p : mapped) {
         if (p.second == blk) {
-            //LLAMA_LOG_DEBUG("(imatrix -> %d) ", p.first);
+            LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
             return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
         }
     }
@@ -621,14 +618,20 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };
 
+    std::vector<int> prune_list = {};
+    if (params->prune_layers) {
+        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
+    }
+
     // copy the KV pairs from the input file
     gguf_set_kv     (ctx_out.get(), ml.meta.get());
     gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
     gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
 
-    // ToDo: Add test for --tensor-prune condition
-    const auto block_count = gguf_get_val_u32(ctx_out.get(), LLM_KV_BLOCK_COUNT) - prune_map.size();
-    gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), block_count);
+    if (!prune_list.empty()) {
+        const auto block_count = gguf_get_val_u32(ctx_out.get(), LLM_KV_BLOCK_COUNT) - prune_list.size();
+        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), block_count);
+    }
 
     // Remove split metadata
     gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
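The bookkeeping is simply the source block count minus the number of pruned layers: pruning two layers from a 32-block model (numbers chosen for illustration) writes block_count = 30 into the output header, so loaders of the quantized file build the smaller graph.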
@@ -661,8 +664,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
     tensors.reserve(ml.weights_map.size());
     for (const auto & it : ml.weights_map) {
-        // ToDo: Add test for --tensor-prune condition
-        const std::string remapped_name(remap_layer(it.first, prune_map, mapped, next_blk_id));
+        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, next_blk_id));
         if (remapped_name == "X") {
             if (it.first.find("attn_v.weight") != std::string::npos ||
                 it.first.find("attn_qkv.weight") != std::string::npos ||
@@ -673,7 +675,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             continue;
         } else if (remapped_name != it.first) {
             ggml_set_name(it.second.tensor, remapped_name.c_str());
-            //LLAMA_LOG_DEBUG("%s: tensor %s remmaped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
+            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
         }
         tensors.push_back(&it.second);
     }
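remap_layer itself is not part of this diff; judging from the call sites its contract appears to be: return "X" for tensors belonging to a pruned block (the loop then skips them, after the attention-tensor special-casing whose body lies outside this hunk), and otherwise return the name with its block index shifted down so the surviving blocks stay contiguous. A hypothetical before/after with layer 3 pruned (tensor names chosen for illustration):

    blk.2.ffn_down.weight -> blk.2.ffn_down.weight   (unchanged)
    blk.3.ffn_down.weight -> "X"                     (pruned, skipped by the loop)
    blk.4.ffn_down.weight -> blk.3.ffn_down.weight   (renumbered via ggml_set_name)

remap_imatrix, touched in an earlier hunk, performs the matching lookup so imatrix data follows a renumbered tensor.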
@@ -1019,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
        /*.imatrix      =*/ nullptr,
        /*.kv_overrides =*/ nullptr,
        /*.tensor_type  =*/ nullptr,
+       /*.prune_layers =*/ nullptr
     };
 
     return result;
