@@ -145,6 +145,45 @@ static void common_params_handle_model_default(common_params & params) {
145145 }
146146}
147147
148+ const std::initializer_list<std::pair<const char *, ggml_type>> kv_cache_types = {
149+ {" f32" , GGML_TYPE_F32},
150+ {" f16" , GGML_TYPE_F16},
151+ {" bf16" , GGML_TYPE_BF16},
152+ {" q8_0" , GGML_TYPE_Q8_0},
153+ {" q4_0" , GGML_TYPE_Q4_0},
154+ {" q4_1" , GGML_TYPE_Q4_1},
155+ {" iq4_nl" , GGML_TYPE_IQ4_NL},
156+ {" q5_0" , GGML_TYPE_Q5_0},
157+ {" q5_1" , GGML_TYPE_Q5_1},
158+ };
159+
160+ static ggml_type kv_cache_type_from_str (const std::string & s) {
161+ for (const auto & kv : kv_cache_types) {
162+ if (kv.first == s) {
163+ return kv.second ;
164+ }
165+ }
166+ throw std::runtime_error (" Unsupported cache type: " + s);
167+ }
168+
169+ static const char * kv_cache_type_to_str (const ggml_type t) {
170+ for (const auto & kv : kv_cache_types) {
171+ if (kv.second == t) {
172+ return kv.first ;
173+ }
174+ }
175+ throw std::runtime_error (" Unsupported cache type: " + std::to_string (t));
176+ }
177+
178+ static std::string get_all_kv_cache_types () {
179+ std::ostringstream msg;
180+ size_t size = kv_cache_types.size ();
181+ for (size_t i = 0 ; i < size; i++) {
182+ msg << (kv_cache_types.begin () + i)->first << (i+1 == size ? " " : " , " );
183+ }
184+ return msg.str ();
185+ }
186+
148187//
149188// CLI argument parsing functions
150189//
@@ -1174,18 +1213,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
11741213 ).set_env (" LLAMA_ARG_NO_KV_OFFLOAD" ));
11751214 add_opt (common_arg (
11761215 {" -ctk" , " --cache-type-k" }, " TYPE" ,
1177- string_format (" KV cache data type for K (default: %s)" , params.cache_type_k .c_str ()),
1216+ string_format (
1217+ " KV cache data type for K\n "
1218+ " allowed values: %s\n "
1219+ " (default: %s)" ,
1220+ get_all_kv_cache_types ().c_str (),
1221+ kv_cache_type_to_str (params.cache_type_k )
1222+ ),
11781223 [](common_params & params, const std::string & value) {
1179- // TODO: get the type right here
1180- params.cache_type_k = value;
1224+ params.cache_type_k = kv_cache_type_from_str (value);
11811225 }
11821226 ).set_env (" LLAMA_ARG_CACHE_TYPE_K" ));
11831227 add_opt (common_arg (
11841228 {" -ctv" , " --cache-type-v" }, " TYPE" ,
1185- string_format (" KV cache data type for V (default: %s)" , params.cache_type_v .c_str ()),
1229+ string_format (
1230+ " KV cache data type for V\n "
1231+ " allowed values: %s\n "
1232+ " (default: %s)" ,
1233+ get_all_kv_cache_types ().c_str (),
1234+ kv_cache_type_to_str (params.cache_type_v )
1235+ ),
11861236 [](common_params & params, const std::string & value) {
1187- // TODO: get the type right here
1188- params.cache_type_v = value;
1237+ params.cache_type_v = kv_cache_type_from_str (value);
11891238 }
11901239 ).set_env (" LLAMA_ARG_CACHE_TYPE_V" ));
11911240 add_opt (common_arg (
0 commit comments