@@ -252,35 +252,32 @@ static const std::vector<std::string> ALLOWED_TENSOR_TYPE = {
252252 " attn_k" ,
253253 " attn_kv_a_mqa" ,
254254 " attn_kv_b" ,
255- " attn_out" ,
255+ " attn_o" ,
256+ " attn_output" ,
257+ " attn_q" ,
256258 " attn_q_a" ,
257259 " attn_q_b" ,
258- " attn_q" ,
259260 " attn_qkv" ,
260261 " attn_v" ,
261262 " channel_mix_key" ,
262263 " channel_mix_receptance" ,
263264 " channel_mix_value" ,
264- " cls_out" ,
265265 " cls" ,
266- " dec_attn_k" ,
267- " dec_attn_out" ,
268- " dec_attn_q" ,
269- " dec_attn_v" ,
270- " dec_cross_attn_k" ,
271- " dec_cross_attn_out" ,
272- " dec_cross_attn_q" ,
273- " dec_cross_attn_v" ,
266+ " cls.output" ,
267+ " cross_attn_k" ,
268+ " cross_attn_o" ,
269+ " cross_attn_q" ,
270+ " cross_attn_v" ,
274271 " ffn_act" ,
275- " ffn_down_exp" ,
276- " ffn_down_shexp" ,
277272 " ffn_down" ,
278- " ffn_gate_exp " ,
279- " ffn_gate_shexp " ,
273+ " ffn_down_exps " ,
274+ " ffn_down_shexp " ,
280275 " ffn_gate" ,
281- " ffn_up_exp " ,
282- " ffn_up_shexp " ,
276+ " ffn_gate_exps " ,
277+ " ffn_gate_shexp " ,
283278 " ffn_up" ,
279+ " ffn_up_exps" ,
280+ " ffn_up_shexp" ,
284281 " ssm_in" ,
285282 " ssm_out" ,
286283 " time_mix_gate" ,
@@ -296,7 +293,7 @@ struct tensor_quantization {
296293 ggml_type quant = GGML_TYPE_COUNT;
297294};
298295
299- static bool string_parse_tensor_type (const char * data, std::vector<tensor_quantization> & tensor_type) {
296+ static bool parse_tensor_type (const char * data, std::vector<tensor_quantization> & tensor_type) {
300297 const char * sep = strchr (data, ' =' );
301298 if (sep == nullptr ) {
302299 printf (" \n %s: malformed tensor type '%s'\n\n " , __func__, data);
@@ -322,7 +319,7 @@ static bool string_parse_tensor_type(const char * data, std::vector<tensor_quant
322319 bool found = false ;
323320 for (const auto & allowed : ALLOWED_TENSOR_TYPE) {
324321 // check if an allowed tensor exists and it's at the end of the kv string
325- if (tn.length () - allowed.length () == tn.find (allowed)) {
322+ if (tn.length () - allowed.length () == tn.find (allowed) && tn == allowed ) {
326323 found = true ;
327324 break ;
328325 }
@@ -379,7 +376,7 @@ int main(int argc, char ** argv) {
379376 usage (argv[0 ]);
380377 }
381378 } else if (strcmp (argv[arg_idx], " --tensor-type" ) == 0 ) {
382- if (arg_idx == argc-1 || !string_parse_tensor_type (argv[++arg_idx], tensor_types)) {
379+ if (arg_idx == argc-1 || !parse_tensor_type (argv[++arg_idx], tensor_types)) {
383380 usage (argv[0 ]);
384381 }
385382 } else if (strcmp (argv[arg_idx], " --override-kv" ) == 0 ) {
0 commit comments