Skip to content

Commit 61bb6e2

Browse files
committed
Remove allowed tensor checks
1 parent 70daa41 commit 61bb6e2

File tree

3 files changed

+17
-77
lines changed

3 files changed

+17
-77
lines changed

src/llama-quant.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
#include <thread>
1515
#include <unordered_map>
1616

17+
// Quantization types. Changes to this struct must be replicated in quantize.cpp
18+
struct tensor_quantization {
19+
std::string name;
20+
ggml_type quant = GGML_TYPE_COUNT;
21+
};
22+
1723
static void zeros(std::ofstream & file, size_t n) {
1824
char zero = 0;
1925
for (size_t i = 0; i < n; ++i) {
@@ -793,20 +799,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
793799
const std::string tensor_name(tensor->name);
794800
for (const auto & [tname, qtype] : tensor_types) {
795801
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
796-
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
797-
if (tensor_name.find(allowed) != std::string::npos) {
798-
if (qtype != new_type) {
799-
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
800-
new_type = qtype;
801-
break;
802-
}
803-
}
802+
if (qtype != new_type) {
803+
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
804+
new_type = qtype;
805+
break; // if two or more types are specified for the tensor, first match wins
804806
}
805-
goto loop_exit; // if two or more types are specified for the tensor, first match wins
806807
}
807808
}
808809
}
809-
loop_exit:;
810810
}
811811

812812
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {

src/llama-quant.h

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1 @@
11
#pragma once
2-
3-
#include <string>
4-
#include <vector>
5-
6-
#include "ggml.h"
7-
8-
// Allowed tensors for arbitrary quantization with --tensor-type option
9-
static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
10-
"attn_k",
11-
"attn_k_b",
12-
"attn_kv_a_mqa",
13-
"attn_kv_b",
14-
"attn_o",
15-
"attn_output",
16-
"attn_q",
17-
"attn_q_a",
18-
"attn_q_b",
19-
"attn_qkv",
20-
"attn_rel_b",
21-
"attn_v",
22-
"attn_v_b",
23-
"channel_mix_key",
24-
"channel_mix_receptance",
25-
"channel_mix_value",
26-
"cls",
27-
"cls.output",
28-
"conv1",
29-
"conv1d",
30-
"conv2",
31-
"cross_attn_k",
32-
"cross_attn_o",
33-
"cross_attn_q",
34-
"cross_attn_rel_b",
35-
"cross_attn_v",
36-
"dw",
37-
"ffn_down",
38-
"ffn_down_exps",
39-
"ffn_down_shexp",
40-
"ffn_gate",
41-
"ffn_gate_exps",
42-
"ffn_gate_shexp",
43-
"ffn_up",
44-
"ffn_up_exps",
45-
"ffn_up_shexp",
46-
"pw1",
47-
"pw1",
48-
"ssm_a",
49-
"ssm_conv1d",
50-
"ssm_dt",
51-
"ssm_in",
52-
"ssm_out",
53-
"ssm_x",
54-
"time_mix_gate",
55-
"time_mix_key",
56-
"time_mix_output",
57-
"time_mix_receptance",
58-
"time_mix_value",
59-
"token_types"
60-
};
61-
62-
// Quantization types
63-
struct tensor_quantization {
64-
std::string name;
65-
ggml_type quant = GGML_TYPE_COUNT;
66-
};

tools/quantize/quantize.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "common.h"
22
#include "llama.h"
3-
#include "llama-quant.h"
43

54
#include <cstdio>
65
#include <cstring>
@@ -58,6 +57,12 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
5857
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
5958
};
6059

60+
// Quantization types. Changes to this struct must be replicated in llama-quantize.cpp
61+
struct tensor_quantization {
62+
std::string name;
63+
ggml_type quant = GGML_TYPE_COUNT;
64+
};
65+
6166
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
6267
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
6368
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
@@ -245,7 +250,7 @@ static ggml_type parse_ggml_type(const char * arg) {
245250
return type;
246251
}
247252
}
248-
fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
253+
fprintf(stderr, "\n%s: invalid ggml_type '%s'\n\n", __func__, arg);
249254
return GGML_TYPE_COUNT;
250255
}
251256

0 commit comments

Comments
 (0)