@@ -146,84 +146,166 @@ static bool load_yaml_config(const std::string& config_path, common_params& para
146146 try {
147147 YAML::Node config = YAML::LoadFile (config_path);
148148
149- // Parse main parameters
150- if (config[" n_predict" ]) params.n_predict = config[" n_predict" ].as <int32_t >();
151- if (config[" n_ctx" ]) params.n_ctx = config[" n_ctx" ].as <int32_t >();
152- if (config[" n_batch" ]) params.n_batch = config[" n_batch" ].as <int32_t >();
153- if (config[" n_ubatch" ]) params.n_ubatch = config[" n_ubatch" ].as <int32_t >();
154- if (config[" n_keep" ]) params.n_keep = config[" n_keep" ].as <int32_t >();
155- if (config[" n_chunks" ]) params.n_chunks = config[" n_chunks" ].as <int32_t >();
156- if (config[" n_parallel" ]) params.n_parallel = config[" n_parallel" ].as <int32_t >();
157- if (config[" n_sequences" ]) params.n_sequences = config[" n_sequences" ].as <int32_t >();
158- if (config[" grp_attn_n" ]) params.grp_attn_n = config[" grp_attn_n" ].as <int32_t >();
159- if (config[" grp_attn_w" ]) params.grp_attn_w = config[" grp_attn_w" ].as <int32_t >();
160- if (config[" n_print" ]) params.n_print = config[" n_print" ].as <int32_t >();
161- if (config[" rope_freq_base" ]) params.rope_freq_base = config[" rope_freq_base" ].as <float >();
162- if (config[" rope_freq_scale" ]) params.rope_freq_scale = config[" rope_freq_scale" ].as <float >();
163- if (config[" yarn_ext_factor" ]) params.yarn_ext_factor = config[" yarn_ext_factor" ].as <float >();
164- if (config[" yarn_attn_factor" ]) params.yarn_attn_factor = config[" yarn_attn_factor" ].as <float >();
165- if (config[" yarn_beta_fast" ]) params.yarn_beta_fast = config[" yarn_beta_fast" ].as <float >();
166- if (config[" yarn_beta_slow" ]) params.yarn_beta_slow = config[" yarn_beta_slow" ].as <float >();
167- if (config[" yarn_orig_ctx" ]) params.yarn_orig_ctx = config[" yarn_orig_ctx" ].as <int32_t >();
168- if (config[" n_gpu_layers" ]) params.n_gpu_layers = config[" n_gpu_layers" ].as <int32_t >();
169- if (config[" main_gpu" ]) params.main_gpu = config[" main_gpu" ].as <int32_t >();
170-
171- // Parse string parameters
172- if (config[" model_alias" ]) params.model_alias = config[" model_alias" ].as <std::string>();
173- if (config[" hf_token" ]) params.hf_token = config[" hf_token" ].as <std::string>();
174- if (config[" prompt" ]) params.prompt = config[" prompt" ].as <std::string>();
175- if (config[" system_prompt" ]) params.system_prompt = config[" system_prompt" ].as <std::string>();
176- if (config[" prompt_file" ]) params.prompt_file = config[" prompt_file" ].as <std::string>();
177- if (config[" path_prompt_cache" ]) params.path_prompt_cache = config[" path_prompt_cache" ].as <std::string>();
178- if (config[" input_prefix" ]) params.input_prefix = config[" input_prefix" ].as <std::string>();
179- if (config[" input_suffix" ]) params.input_suffix = config[" input_suffix" ].as <std::string>();
180- if (config[" lookup_cache_static" ]) params.lookup_cache_static = config[" lookup_cache_static" ].as <std::string>();
181- if (config[" lookup_cache_dynamic" ]) params.lookup_cache_dynamic = config[" lookup_cache_dynamic" ].as <std::string>();
182- if (config[" logits_file" ]) params.logits_file = config[" logits_file" ].as <std::string>();
183-
184- // Parse boolean parameters
185- if (config[" lora_init_without_apply" ]) params.lora_init_without_apply = config[" lora_init_without_apply" ].as <bool >();
186- if (config[" offline" ]) params.offline = config[" offline" ].as <bool >();
187-
188- // Parse integer parameters
189- if (config[" verbosity" ]) params.verbosity = config[" verbosity" ].as <int32_t >();
190- if (config[" control_vector_layer_start" ]) params.control_vector_layer_start = config[" control_vector_layer_start" ].as <int32_t >();
191- if (config[" control_vector_layer_end" ]) params.control_vector_layer_end = config[" control_vector_layer_end" ].as <int32_t >();
192- if (config[" ppl_stride" ]) params.ppl_stride = config[" ppl_stride" ].as <int32_t >();
193- if (config[" ppl_output_type" ]) params.ppl_output_type = config[" ppl_output_type" ].as <int32_t >();
194-
195- // Parse array parameters
149+ // Parse main parameters with bounds checking
150+ if (config[" n_predict" ] && config[" n_predict" ].IsScalar ()) {
151+ params.n_predict = config[" n_predict" ].as <int32_t >();
152+ }
153+ if (config[" n_ctx" ] && config[" n_ctx" ].IsScalar ()) {
154+ params.n_ctx = config[" n_ctx" ].as <int32_t >();
155+ }
156+ if (config[" n_batch" ] && config[" n_batch" ].IsScalar ()) {
157+ params.n_batch = config[" n_batch" ].as <int32_t >();
158+ }
159+ if (config[" n_ubatch" ] && config[" n_ubatch" ].IsScalar ()) {
160+ params.n_ubatch = config[" n_ubatch" ].as <int32_t >();
161+ }
162+ if (config[" n_keep" ] && config[" n_keep" ].IsScalar ()) {
163+ params.n_keep = config[" n_keep" ].as <int32_t >();
164+ }
165+ if (config[" n_chunks" ] && config[" n_chunks" ].IsScalar ()) {
166+ params.n_chunks = config[" n_chunks" ].as <int32_t >();
167+ }
168+ if (config[" n_parallel" ] && config[" n_parallel" ].IsScalar ()) {
169+ params.n_parallel = config[" n_parallel" ].as <int32_t >();
170+ }
171+ if (config[" n_sequences" ] && config[" n_sequences" ].IsScalar ()) {
172+ params.n_sequences = config[" n_sequences" ].as <int32_t >();
173+ }
174+ if (config[" grp_attn_n" ] && config[" grp_attn_n" ].IsScalar ()) {
175+ params.grp_attn_n = config[" grp_attn_n" ].as <int32_t >();
176+ }
177+ if (config[" grp_attn_w" ] && config[" grp_attn_w" ].IsScalar ()) {
178+ params.grp_attn_w = config[" grp_attn_w" ].as <int32_t >();
179+ }
180+ if (config[" n_print" ] && config[" n_print" ].IsScalar ()) {
181+ params.n_print = config[" n_print" ].as <int32_t >();
182+ }
183+ if (config[" rope_freq_base" ] && config[" rope_freq_base" ].IsScalar ()) {
184+ params.rope_freq_base = config[" rope_freq_base" ].as <float >();
185+ }
186+ if (config[" rope_freq_scale" ] && config[" rope_freq_scale" ].IsScalar ()) {
187+ params.rope_freq_scale = config[" rope_freq_scale" ].as <float >();
188+ }
189+ if (config[" yarn_ext_factor" ] && config[" yarn_ext_factor" ].IsScalar ()) {
190+ params.yarn_ext_factor = config[" yarn_ext_factor" ].as <float >();
191+ }
192+ if (config[" yarn_attn_factor" ] && config[" yarn_attn_factor" ].IsScalar ()) {
193+ params.yarn_attn_factor = config[" yarn_attn_factor" ].as <float >();
194+ }
195+ if (config[" yarn_beta_fast" ] && config[" yarn_beta_fast" ].IsScalar ()) {
196+ params.yarn_beta_fast = config[" yarn_beta_fast" ].as <float >();
197+ }
198+ if (config[" yarn_beta_slow" ] && config[" yarn_beta_slow" ].IsScalar ()) {
199+ params.yarn_beta_slow = config[" yarn_beta_slow" ].as <float >();
200+ }
201+ if (config[" yarn_orig_ctx" ] && config[" yarn_orig_ctx" ].IsScalar ()) {
202+ params.yarn_orig_ctx = config[" yarn_orig_ctx" ].as <int32_t >();
203+ }
204+ if (config[" n_gpu_layers" ] && config[" n_gpu_layers" ].IsScalar ()) {
205+ params.n_gpu_layers = config[" n_gpu_layers" ].as <int32_t >();
206+ }
207+ if (config[" main_gpu" ] && config[" main_gpu" ].IsScalar ()) {
208+ params.main_gpu = config[" main_gpu" ].as <int32_t >();
209+ }
210+
211+ // Parse string parameters with type checking
212+ if (config[" model_alias" ] && config[" model_alias" ].IsScalar ()) {
213+ params.model_alias = config[" model_alias" ].as <std::string>();
214+ }
215+ if (config[" hf_token" ] && config[" hf_token" ].IsScalar ()) {
216+ params.hf_token = config[" hf_token" ].as <std::string>();
217+ }
218+ if (config[" prompt" ] && config[" prompt" ].IsScalar ()) {
219+ params.prompt = config[" prompt" ].as <std::string>();
220+ }
221+ if (config[" system_prompt" ] && config[" system_prompt" ].IsScalar ()) {
222+ params.system_prompt = config[" system_prompt" ].as <std::string>();
223+ }
224+ if (config[" prompt_file" ] && config[" prompt_file" ].IsScalar ()) {
225+ params.prompt_file = config[" prompt_file" ].as <std::string>();
226+ }
227+ if (config[" path_prompt_cache" ] && config[" path_prompt_cache" ].IsScalar ()) {
228+ params.path_prompt_cache = config[" path_prompt_cache" ].as <std::string>();
229+ }
230+ if (config[" input_prefix" ] && config[" input_prefix" ].IsScalar ()) {
231+ params.input_prefix = config[" input_prefix" ].as <std::string>();
232+ }
233+ if (config[" input_suffix" ] && config[" input_suffix" ].IsScalar ()) {
234+ params.input_suffix = config[" input_suffix" ].as <std::string>();
235+ }
236+ if (config[" lookup_cache_static" ] && config[" lookup_cache_static" ].IsScalar ()) {
237+ params.lookup_cache_static = config[" lookup_cache_static" ].as <std::string>();
238+ }
239+ if (config[" lookup_cache_dynamic" ] && config[" lookup_cache_dynamic" ].IsScalar ()) {
240+ params.lookup_cache_dynamic = config[" lookup_cache_dynamic" ].as <std::string>();
241+ }
242+ if (config[" logits_file" ] && config[" logits_file" ].IsScalar ()) {
243+ params.logits_file = config[" logits_file" ].as <std::string>();
244+ }
245+
246+ // Parse boolean parameters with type checking
247+ if (config[" lora_init_without_apply" ] && config[" lora_init_without_apply" ].IsScalar ()) {
248+ params.lora_init_without_apply = config[" lora_init_without_apply" ].as <bool >();
249+ }
250+ if (config[" offline" ] && config[" offline" ].IsScalar ()) {
251+ params.offline = config[" offline" ].as <bool >();
252+ }
253+
254+ // Parse integer parameters with type checking
255+ if (config[" verbosity" ] && config[" verbosity" ].IsScalar ()) {
256+ params.verbosity = config[" verbosity" ].as <int32_t >();
257+ }
258+ if (config[" control_vector_layer_start" ] && config[" control_vector_layer_start" ].IsScalar ()) {
259+ params.control_vector_layer_start = config[" control_vector_layer_start" ].as <int32_t >();
260+ }
261+ if (config[" control_vector_layer_end" ] && config[" control_vector_layer_end" ].IsScalar ()) {
262+ params.control_vector_layer_end = config[" control_vector_layer_end" ].as <int32_t >();
263+ }
264+ if (config[" ppl_stride" ] && config[" ppl_stride" ].IsScalar ()) {
265+ params.ppl_stride = config[" ppl_stride" ].as <int32_t >();
266+ }
267+ if (config[" ppl_output_type" ] && config[" ppl_output_type" ].IsScalar ()) {
268+ params.ppl_output_type = config[" ppl_output_type" ].as <int32_t >();
269+ }
270+
271+ // Parse array parameters with proper bounds checking
196272 if (config[" in_files" ] && config[" in_files" ].IsSequence ()) {
197273 params.in_files .clear ();
274+ params.in_files .reserve (config[" in_files" ].size ());
198275 for (const auto & file : config[" in_files" ]) {
199- params.in_files .push_back (file.as <std::string>());
276+ if (file.IsScalar ()) {
277+ params.in_files .push_back (file.as <std::string>());
278+ }
200279 }
201280 }
202281
203282 if (config[" antiprompt" ] && config[" antiprompt" ].IsSequence ()) {
204283 params.antiprompt .clear ();
284+ params.antiprompt .reserve (config[" antiprompt" ].size ());
205285 for (const auto & prompt : config[" antiprompt" ]) {
206- params.antiprompt .push_back (prompt.as <std::string>());
286+ if (prompt.IsScalar ()) {
287+ params.antiprompt .push_back (prompt.as <std::string>());
288+ }
207289 }
208290 }
209291
210- if (config[" sampling" ]) {
292+ if (config[" sampling" ] && config[ " sampling " ]. IsMap () ) {
211293 parse_yaml_sampling (config[" sampling" ], params.sampling );
212294 }
213295
214- if (config[" model" ]) {
296+ if (config[" model" ] && config[ " model " ]. IsMap () ) {
215297 parse_yaml_model (config[" model" ], params.model );
216298 }
217299
218- if (config[" speculative" ]) {
300+ if (config[" speculative" ] && config[ " speculative " ]. IsMap () ) {
219301 parse_yaml_speculative (config[" speculative" ], params.speculative );
220302 }
221303
222- if (config[" vocoder" ]) {
304+ if (config[" vocoder" ] && config[ " vocoder " ]. IsMap () ) {
223305 parse_yaml_vocoder (config[" vocoder" ], params.vocoder );
224306 }
225307
226- if (config[" diffusion" ]) {
308+ if (config[" diffusion" ] && config[ " diffusion " ]. IsMap () ) {
227309 parse_yaml_diffusion (config[" diffusion" ], params.diffusion );
228310 }
229311
@@ -234,6 +316,9 @@ static bool load_yaml_config(const std::string& config_path, common_params& para
234316 } catch (const std::exception& e) {
235317 fprintf (stderr, " Error loading YAML config: %s\n " , e.what ());
236318 return false ;
319+ } catch (...) {
320+ fprintf (stderr, " Unknown error loading YAML config\n " );
321+ return false ;
237322 }
238323}
239324
0 commit comments