@@ -346,30 +346,25 @@ void parse_args(int argc, const char** argv, SDParams& params) {
346346 invalid_arg = true ;
347347 break ;
348348 }
349- std::string type = argv[i];
350- if (type == " f32" ) {
351- params.wtype = SD_TYPE_F32;
352- } else if (type == " f16" ) {
353- params.wtype = SD_TYPE_F16;
354- } else if (type == " q4_0" ) {
355- params.wtype = SD_TYPE_Q4_0;
356- } else if (type == " q4_1" ) {
357- params.wtype = SD_TYPE_Q4_1;
358- } else if (type == " q5_0" ) {
359- params.wtype = SD_TYPE_Q5_0;
360- } else if (type == " q5_1" ) {
361- params.wtype = SD_TYPE_Q5_1;
362- } else if (type == " q8_0" ) {
363- params.wtype = SD_TYPE_Q8_0;
364- } else if (type == " q2_k" ) {
365- params.wtype = SD_TYPE_Q2_K;
366- } else if (type == " q3_k" ) {
367- params.wtype = SD_TYPE_Q3_K;
368- } else if (type == " q4_k" ) {
369- params.wtype = SD_TYPE_Q4_K;
370- } else {
371- fprintf (stderr, " error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n " ,
372- type.c_str ());
349+ std::string type = argv[i];
350+ bool found = false ;
351+ std::string valid_types = " " ;
352+ for (size_t i = 0 ; i < SD_TYPE_COUNT; i++) {
353+ auto trait = ggml_get_type_traits ((ggml_type)i);
354+ std::string name (trait->type_name );
355+ if (i)
356+ valid_types += " , " ;
357+ valid_types += name;
358+ if (type == name) {
359+ params.wtype = (enum sd_type_t )i;
360+ found = true ;
361+ break ;
362+ }
363+ }
364+ if (!found) {
365+ fprintf (stderr, " error: invalid weight format %s, must be one of [%s]\n " ,
366+ type.c_str (),
367+ valid_types.c_str ());
373368 exit (1 );
374369 }
375370 } else if (arg == " --lora-model-dir" ) {
0 commit comments