Skip to content

Commit fdc466e

Browse files
committed
cli: allow any wtype
1 parent 4570715 commit fdc466e

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

examples/cli/main.cpp

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -346,30 +346,25 @@ void parse_args(int argc, const char** argv, SDParams& params) {
346346
invalid_arg = true;
347347
break;
348348
}
349-
std::string type = argv[i];
350-
if (type == "f32") {
351-
params.wtype = SD_TYPE_F32;
352-
} else if (type == "f16") {
353-
params.wtype = SD_TYPE_F16;
354-
} else if (type == "q4_0") {
355-
params.wtype = SD_TYPE_Q4_0;
356-
} else if (type == "q4_1") {
357-
params.wtype = SD_TYPE_Q4_1;
358-
} else if (type == "q5_0") {
359-
params.wtype = SD_TYPE_Q5_0;
360-
} else if (type == "q5_1") {
361-
params.wtype = SD_TYPE_Q5_1;
362-
} else if (type == "q8_0") {
363-
params.wtype = SD_TYPE_Q8_0;
364-
} else if (type == "q2_k") {
365-
params.wtype = SD_TYPE_Q2_K;
366-
} else if (type == "q3_k") {
367-
params.wtype = SD_TYPE_Q3_K;
368-
} else if (type == "q4_k") {
369-
params.wtype = SD_TYPE_Q4_K;
370-
} else {
371-
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n",
372-
type.c_str());
349+
std::string type = argv[i];
350+
bool found = false;
351+
std::string valid_types = "";
352+
for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
353+
auto trait = ggml_get_type_traits((ggml_type)i);
354+
std::string name(trait->type_name);
355+
if (i)
356+
valid_types += ", ";
357+
valid_types += name;
358+
if (type == name) {
359+
params.wtype = (enum sd_type_t)i;
360+
found = true;
361+
break;
362+
}
363+
}
364+
if (!found) {
365+
fprintf(stderr, "error: invalid weight format %s, must be one of [%s]\n",
366+
type.c_str(),
367+
valid_types.c_str());
373368
exit(1);
374369
}
375370
} else if (arg == "--lora-model-dir") {

0 commit comments

Comments
 (0)