Skip to content

Commit ba35f29

Browse files
committed
common : improve ctv ctk cli argument
1 parent 8faa1d4 commit ba35f29

File tree

3 files changed

+59
-42
lines changed

3 files changed

+59
-42
lines changed

common/arg.cpp

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,45 @@ static void common_params_handle_model_default(common_params & params) {
145145
}
146146
}
147147

148+
const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = {
149+
{"f32", GGML_TYPE_F32},
150+
{"f16", GGML_TYPE_F16},
151+
{"bf16", GGML_TYPE_BF16},
152+
{"q8_0", GGML_TYPE_Q8_0},
153+
{"q4_0", GGML_TYPE_Q4_0},
154+
{"q4_1", GGML_TYPE_Q4_1},
155+
{"iq4_nl", GGML_TYPE_IQ4_NL},
156+
{"q5_0", GGML_TYPE_Q5_0},
157+
{"q5_1", GGML_TYPE_Q5_1},
158+
};
159+
160+
// Resolve a user-supplied cache-type name (e.g. "q8_0") to its ggml_type.
// Throws std::runtime_error when the name is not in kv_cache_types.
static ggml_type kv_cache_type_from_str(const std::string & s) {
    for (const auto & [name, type] : kv_cache_types) {
        if (s == name) {
            return type;
        }
    }
    throw std::runtime_error("Unsupported cache type: " + s);
}
168+
169+
static const char * kv_cache_type_to_str(const ggml_type t) {
170+
for (const auto & kv : kv_cache_types) {
171+
if (kv.second == t) {
172+
return kv.first;
173+
}
174+
}
175+
throw std::runtime_error("Unsupported cache type: " + std::to_string(t));
176+
}
177+
178+
static std::string get_all_kv_cache_types() {
179+
std::ostringstream msg;
180+
size_t size = kv_cache_types.size();
181+
for (size_t i = 0; i < size; i++) {
182+
msg << (kv_cache_types.begin() + i)->first << (i+1 == size ? "" : ", ");
183+
}
184+
return msg.str();
185+
}
186+
148187
//
149188
// CLI argument parsing functions
150189
//
@@ -1174,18 +1213,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
11741213
).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
11751214
add_opt(common_arg(
11761215
{"-ctk", "--cache-type-k"}, "TYPE",
1177-
string_format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()),
1216+
string_format(
1217+
"KV cache data type for K\n"
1218+
"allowed values: %s\n"
1219+
"(default: %s)",
1220+
get_all_kv_cache_types().c_str(),
1221+
kv_cache_type_to_str(params.cache_type_k)
1222+
),
11781223
[](common_params & params, const std::string & value) {
1179-
// TODO: get the type right here
1180-
params.cache_type_k = value;
1224+
params.cache_type_k = kv_cache_type_from_str(value);
11811225
}
11821226
).set_env("LLAMA_ARG_CACHE_TYPE_K"));
11831227
add_opt(common_arg(
11841228
{"-ctv", "--cache-type-v"}, "TYPE",
1185-
string_format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()),
1229+
string_format(
1230+
"KV cache data type for V\n"
1231+
"allowed values: %s\n"
1232+
"(default: %s)",
1233+
get_all_kv_cache_types().c_str(),
1234+
kv_cache_type_to_str(params.cache_type_v)
1235+
),
11861236
[](common_params & params, const std::string & value) {
1187-
// TODO: get the type right here
1188-
params.cache_type_v = value;
1237+
params.cache_type_v = kv_cache_type_from_str(value);
11891238
}
11901239
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
11911240
add_opt(common_arg(

common/common.cpp

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,38 +1015,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
10151015
return mparams;
10161016
}
10171017

1018-
static ggml_type kv_cache_type_from_str(const std::string & s) {
1019-
if (s == "f32") {
1020-
return GGML_TYPE_F32;
1021-
}
1022-
if (s == "f16") {
1023-
return GGML_TYPE_F16;
1024-
}
1025-
if (s == "bf16") {
1026-
return GGML_TYPE_BF16;
1027-
}
1028-
if (s == "q8_0") {
1029-
return GGML_TYPE_Q8_0;
1030-
}
1031-
if (s == "q4_0") {
1032-
return GGML_TYPE_Q4_0;
1033-
}
1034-
if (s == "q4_1") {
1035-
return GGML_TYPE_Q4_1;
1036-
}
1037-
if (s == "iq4_nl") {
1038-
return GGML_TYPE_IQ4_NL;
1039-
}
1040-
if (s == "q5_0") {
1041-
return GGML_TYPE_Q5_0;
1042-
}
1043-
if (s == "q5_1") {
1044-
return GGML_TYPE_Q5_1;
1045-
}
1046-
1047-
throw std::runtime_error("Unsupported cache type: " + s);
1048-
}
1049-
10501018
struct llama_context_params common_context_params_to_llama(const common_params & params) {
10511019
auto cparams = llama_context_default_params();
10521020

@@ -1081,8 +1049,8 @@ struct llama_context_params common_context_params_to_llama(const common_params &
10811049
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
10821050
}
10831051

1084-
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
1085-
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
1052+
cparams.type_k = params.cache_type_k;
1053+
cparams.type_v = params.cache_type_v;
10861054

10871055
return cparams;
10881056
}

common/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ struct common_params {
286286
bool warmup = true; // warmup run
287287
bool check_tensors = false; // validate tensor data
288288

289-
std::string cache_type_k = "f16"; // KV cache data type for the K
290-
std::string cache_type_v = "f16"; // KV cache data type for the V
289+
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
290+
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
291291

292292
// multimodal models (see examples/llava)
293293
std::string mmproj = ""; // path to multimodal projector // NOLINT

0 commit comments

Comments
 (0)