Skip to content

Commit f037443

Browse files
committed
Fix incorrect block count bug and improve code readability
1 parent 4661940 commit f037443

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

src/llama-quant.cpp

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -627,12 +627,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
627627
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
628628
gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
629629

630-
if (!prune_list.empty()) {
631-
uint32_t block_count = 0;
632-
ml.get_key(LLM_KV_BLOCK_COUNT, block_count);
633-
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), block_count - prune_list.size());
634-
}
635-
636630
// Remove split metadata
637631
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
638632
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
@@ -667,12 +661,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
667661
for (const auto & it : ml.weights_map) {
668662
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, next_blk_id));
669663
if (remapped_name.empty()) {
670-
if (false
671-
|| it.first.find("attn_v.weight") != std::string::npos
672-
|| it.first.find("attn_qkv.weight") != std::string::npos
673-
|| it.first.find("attn_kv_b.weight")!= std::string::npos) {
674-
pruned_attention_w++;
675-
}
664+
if (it.first.find("attn_v.weight") != std::string::npos ||
665+
it.first.find("attn_qkv.weight") != std::string::npos ||
666+
it.first.find("attn_kv_b.weight") != std::string::npos) {
667+
pruned_attention_w++;
668+
}
676669
LLAMA_LOG_DEBUG("%s: prunning tensor %s\n", __func__, it.first.c_str());
677670
continue;
678671
} else if (remapped_name != it.first) {
@@ -681,6 +674,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
681674
}
682675
tensors.push_back(&it.second);
683676
}
677+
if (!prune_list.empty()) {
678+
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), stoi(mapped.rbegin()->second) + 1);
679+
}
684680

685681
// keep_split requires that the weights are sorted by split index
686682
if (params->keep_split) {

0 commit comments

Comments
 (0)