Skip to content

Commit 054ede4

Browse files
committed
Refactor function name and update ALLOWED_TENSOR_TYPE
1 parent 3e9f565 commit 054ede4

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

examples/quantize/quantize.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -252,35 +252,32 @@ static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
252252
"attn_k",
253253
"attn_kv_a_mqa",
254254
"attn_kv_b",
255-
"attn_out",
255+
"attn_o",
256+
"attn_output",
257+
"attn_q",
256258
"attn_q_a",
257259
"attn_q_b",
258-
"attn_q",
259260
"attn_qkv",
260261
"attn_v",
261262
"channel_mix_key",
262263
"channel_mix_receptance",
263264
"channel_mix_value",
264-
"cls_out",
265265
"cls",
266-
"dec_attn_k",
267-
"dec_attn_out",
268-
"dec_attn_q",
269-
"dec_attn_v",
270-
"dec_cross_attn_k",
271-
"dec_cross_attn_out",
272-
"dec_cross_attn_q",
273-
"dec_cross_attn_v",
266+
"cls.output",
267+
"cross_attn_k",
268+
"cross_attn_o",
269+
"cross_attn_q",
270+
"cross_attn_v",
274271
"ffn_act",
275-
"ffn_down_exp",
276-
"ffn_down_shexp",
277272
"ffn_down",
278-
"ffn_gate_exp",
279-
"ffn_gate_shexp",
273+
"ffn_down_exps",
274+
"ffn_down_shexp",
280275
"ffn_gate",
281-
"ffn_up_exp",
282-
"ffn_up_shexp",
276+
"ffn_gate_exps",
277+
"ffn_gate_shexp",
283278
"ffn_up",
279+
"ffn_up_exps",
280+
"ffn_up_shexp",
284281
"ssm_in",
285282
"ssm_out",
286283
"time_mix_gate",
@@ -296,7 +293,7 @@ struct tensor_quantization {
296293
ggml_type quant = GGML_TYPE_COUNT;
297294
};
298295

299-
static bool string_parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
296+
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
300297
const char * sep = strchr(data, '=');
301298
if (sep == nullptr) {
302299
printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
@@ -322,7 +319,7 @@ static bool string_parse_tensor_type(const char * data, std::vector<tensor_quant
322319
bool found = false;
323320
for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
324321
// check if an allowed tensor exists and it's at the end of the kv string
325-
if (tn.length() - allowed.length() == tn.find(allowed)) {
322+
if (tn.length() - allowed.length() == tn.find(allowed) && tn == allowed) {
326323
found = true;
327324
break;
328325
}
@@ -379,7 +376,7 @@ int main(int argc, char ** argv) {
379376
usage(argv[0]);
380377
}
381378
} else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
382-
if (arg_idx == argc-1 || !string_parse_tensor_type(argv[++arg_idx], tensor_types)) {
379+
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
383380
usage(argv[0]);
384381
}
385382
} else if (strcmp(argv[arg_idx], "--override-kv") == 0) {

0 commit comments

Comments
 (0)