@@ -78,10 +78,10 @@ static void collect_keys(const YAML::Node & node, const std::string & prefix, st
7878static void validate_keys (const YAML::Node & root) {
7979 std::set<std::string> found_keys;
8080 collect_keys (root, " " , found_keys);
81-
81+
8282 const auto valid_keys = get_valid_keys ();
8383 std::vector<std::string> unknown_keys;
84-
84+
8585 for (const auto & key : found_keys) {
8686 if (valid_keys.find (key) == valid_keys.end ()) {
8787 bool is_parent = false ;
@@ -96,7 +96,6 @@ static void validate_keys(const YAML::Node & root) {
9696 }
9797 }
9898 }
99-
10099 if (!unknown_keys.empty ()) {
101100 std::ostringstream ss;
102101 ss << " Unknown YAML keys: " ;
@@ -172,11 +171,11 @@ static common_conversation_mode parse_conversation_mode(const std::string & mode
172171bool common_load_yaml_config (const std::string & path, common_params & params) {
173172 try {
174173 YAML::Node root = YAML::LoadFile (path);
175-
174+
176175 validate_keys (root);
177-
176+
178177 fs::path yaml_dir = fs::absolute (path).parent_path ();
179-
178+
180179 if (root[" model" ]) {
181180 auto model = root[" model" ];
182181 if (model[" path" ]) {
@@ -192,15 +191,14 @@ bool common_load_yaml_config(const std::string & path, common_params & params) {
192191 params.model .hf_file = model[" hf_file" ].as <std::string>();
193192 }
194193 }
195-
194+
196195 if (root[" model_alias" ]) params.model_alias = root[" model_alias" ].as <std::string>();
197196 if (root[" hf_token" ]) params.hf_token = root[" hf_token" ].as <std::string>();
198197 if (root[" prompt" ]) params.prompt = root[" prompt" ].as <std::string>();
199198 if (root[" system_prompt" ]) params.system_prompt = root[" system_prompt" ].as <std::string>();
200199 if (root[" prompt_file" ]) {
201200 params.prompt_file = resolve_path (root[" prompt_file" ].as <std::string>(), yaml_dir);
202201 }
203-
204202 if (root[" n_predict" ]) params.n_predict = root[" n_predict" ].as <int32_t >();
205203 if (root[" n_ctx" ]) params.n_ctx = root[" n_ctx" ].as <int32_t >();
206204 if (root[" n_batch" ]) params.n_batch = root[" n_batch" ].as <int32_t >();
@@ -212,18 +210,17 @@ bool common_load_yaml_config(const std::string & path, common_params & params) {
212210 if (root[" grp_attn_n" ]) params.grp_attn_n = root[" grp_attn_n" ].as <int32_t >();
213211 if (root[" grp_attn_w" ]) params.grp_attn_w = root[" grp_attn_w" ].as <int32_t >();
214212 if (root[" n_print" ]) params.n_print = root[" n_print" ].as <int32_t >();
215-
216213 if (root[" rope_freq_base" ]) params.rope_freq_base = root[" rope_freq_base" ].as <float >();
217214 if (root[" rope_freq_scale" ]) params.rope_freq_scale = root[" rope_freq_scale" ].as <float >();
218215 if (root[" yarn_ext_factor" ]) params.yarn_ext_factor = root[" yarn_ext_factor" ].as <float >();
219216 if (root[" yarn_attn_factor" ]) params.yarn_attn_factor = root[" yarn_attn_factor" ].as <float >();
220217 if (root[" yarn_beta_fast" ]) params.yarn_beta_fast = root[" yarn_beta_fast" ].as <float >();
221218 if (root[" yarn_beta_slow" ]) params.yarn_beta_slow = root[" yarn_beta_slow" ].as <float >();
222219 if (root[" yarn_orig_ctx" ]) params.yarn_orig_ctx = root[" yarn_orig_ctx" ].as <int32_t >();
223-
220+
224221 if (root[" n_gpu_layers" ]) params.n_gpu_layers = root[" n_gpu_layers" ].as <int32_t >();
225222 if (root[" main_gpu" ]) params.main_gpu = root[" main_gpu" ].as <int32_t >();
226-
223+
227224 if (root[" split_mode" ]) {
228225 params.split_mode = parse_split_mode (root[" split_mode" ].as <std::string>());
229226 }
@@ -242,7 +239,7 @@ bool common_load_yaml_config(const std::string & path, common_params & params) {
242239 if (root[" conversation_mode" ]) {
243240 params.conversation_mode = parse_conversation_mode (root[" conversation_mode" ].as <std::string>());
244241 }
245-
242+
246243 if (root[" use_mmap" ]) params.use_mmap = root[" use_mmap" ].as <bool >();
247244 if (root[" use_mlock" ]) params.use_mlock = root[" use_mlock" ].as <bool >();
248245 if (root[" verbose_prompt" ]) params.verbose_prompt = root[" verbose_prompt" ].as <bool >();
@@ -255,7 +252,7 @@ bool common_load_yaml_config(const std::string & path, common_params & params) {
255252 if (root[" simple_io" ]) params.simple_io = root[" simple_io" ].as <bool >();
256253 if (root[" interactive" ]) params.interactive = root[" interactive" ].as <bool >();
257254 if (root[" interactive_first" ]) params.interactive_first = root[" interactive_first" ].as <bool >();
258-
255+
259256 if (root[" input_prefix" ]) params.input_prefix = root[" input_prefix" ].as <std::string>();
260257 if (root[" input_suffix" ]) params.input_suffix = root[" input_suffix" ].as <std::string>();
261258 if (root[" logits_file" ]) {
@@ -264,39 +261,39 @@ bool common_load_yaml_config(const std::string & path, common_params & params) {
264261 if (root[" path_prompt_cache" ]) {
265262 params.path_prompt_cache = resolve_path (root[" path_prompt_cache" ].as <std::string>(), yaml_dir);
266263 }
267-
264+
268265 if (root[" cache_type_k" ]) {
269266 params.cache_type_k = parse_ggml_type (root[" cache_type_k" ].as <std::string>());
270267 }
271268 if (root[" cache_type_v" ]) {
272269 params.cache_type_v = parse_ggml_type (root[" cache_type_v" ].as <std::string>());
273270 }
274-
271+
275272 if (root[" antiprompt" ]) {
276273 params.antiprompt .clear ();
277274 for (const auto & item : root[" antiprompt" ]) {
278275 params.antiprompt .push_back (item.as <std::string>());
279276 }
280277 }
281-
278+
282279 if (root[" in_files" ]) {
283280 params.in_files .clear ();
284281 for (const auto & item : root[" in_files" ]) {
285282 params.in_files .push_back (resolve_path (item.as <std::string>(), yaml_dir));
286283 }
287284 }
288-
285+
289286 if (root[" image" ]) {
290287 params.image .clear ();
291288 for (const auto & item : root[" image" ]) {
292289 params.image .push_back (resolve_path (item.as <std::string>(), yaml_dir));
293290 }
294291 }
295-
292+
296293 if (root[" seed" ]) {
297294 params.sampling .seed = root[" seed" ].as <uint32_t >();
298295 }
299-
296+
300297 if (root[" sampling" ]) {
301298 auto sampling = root[" sampling" ];
302299 if (sampling[" seed" ]) params.sampling .seed = sampling[" seed" ].as <uint32_t >();
@@ -330,7 +327,7 @@ bool common_load_yaml_config(const std::string & path, common_params & params) {
330327 if (sampling[" grammar" ]) params.sampling .grammar = sampling[" grammar" ].as <std::string>();
331328 if (sampling[" grammar_lazy" ]) params.sampling .grammar_lazy = sampling[" grammar_lazy" ].as <bool >();
332329 }
333-
330+
334331 return true ;
335332 } catch (const YAML::Exception & e) {
336333 throw std::invalid_argument (" YAML parsing error: " + std::string (e.what ()));
0 commit comments