@@ -45,60 +45,64 @@ using json = nlohmann::ordered_json;
4545
// YAML configuration parsing functions
4747static void parse_yaml_sampling (const YAML::Node& node, common_params_sampling& sampling) {
48- if (node[" seed" ]) sampling.seed = node[" seed" ].as <uint32_t >();
49- if (node[" n_prev" ]) sampling.n_prev = node[" n_prev" ].as <int32_t >();
50- if (node[" n_probs" ]) sampling.n_probs = node[" n_probs" ].as <int32_t >();
51- if (node[" min_keep" ]) sampling.min_keep = node[" min_keep" ].as <int32_t >();
52- if (node[" top_k" ]) sampling.top_k = node[" top_k" ].as <int32_t >();
53- if (node[" top_p" ]) sampling.top_p = node[" top_p" ].as <float >();
54- if (node[" min_p" ]) sampling.min_p = node[" min_p" ].as <float >();
55- if (node[" xtc_probability" ]) sampling.xtc_probability = node[" xtc_probability" ].as <float >();
56- if (node[" xtc_threshold" ]) sampling.xtc_threshold = node[" xtc_threshold" ].as <float >();
57- if (node[" typ_p" ]) sampling.typ_p = node[" typ_p" ].as <float >();
58- if (node[" temp" ]) sampling.temp = node[" temp" ].as <float >();
59- if (node[" dynatemp_range" ]) sampling.dynatemp_range = node[" dynatemp_range" ].as <float >();
60- if (node[" dynatemp_exponent" ]) sampling.dynatemp_exponent = node[" dynatemp_exponent" ].as <float >();
61- if (node[" penalty_last_n" ]) sampling.penalty_last_n = node[" penalty_last_n" ].as <int32_t >();
62- if (node[" penalty_repeat" ]) sampling.penalty_repeat = node[" penalty_repeat" ].as <float >();
63- if (node[" penalty_freq" ]) sampling.penalty_freq = node[" penalty_freq" ].as <float >();
64- if (node[" penalty_present" ]) sampling.penalty_present = node[" penalty_present" ].as <float >();
65- if (node[" dry_multiplier" ]) sampling.dry_multiplier = node[" dry_multiplier" ].as <float >();
66- if (node[" dry_base" ]) sampling.dry_base = node[" dry_base" ].as <float >();
67- if (node[" dry_allowed_length" ]) sampling.dry_allowed_length = node[" dry_allowed_length" ].as <int32_t >();
68- if (node[" dry_penalty_last_n" ]) sampling.dry_penalty_last_n = node[" dry_penalty_last_n" ].as <int32_t >();
69- if (node[" mirostat" ]) sampling.mirostat = node[" mirostat" ].as <int32_t >();
70- if (node[" top_n_sigma" ]) sampling.top_n_sigma = node[" top_n_sigma" ].as <float >();
71- if (node[" mirostat_tau" ]) sampling.mirostat_tau = node[" mirostat_tau" ].as <float >();
72- if (node[" mirostat_eta" ]) sampling.mirostat_eta = node[" mirostat_eta" ].as <float >();
73- if (node[" ignore_eos" ]) sampling.ignore_eos = node[" ignore_eos" ].as <bool >();
74- if (node[" no_perf" ]) sampling.no_perf = node[" no_perf" ].as <bool >();
75- if (node[" timing_per_token" ]) sampling.timing_per_token = node[" timing_per_token" ].as <bool >();
76- if (node[" grammar" ]) sampling.grammar = node[" grammar" ].as <std::string>();
77- if (node[" grammar_lazy" ]) sampling.grammar_lazy = node[" grammar_lazy" ].as <bool >();
48+ if (node[" seed" ] && node[ " seed " ]. IsScalar () ) sampling.seed = node[" seed" ].as <uint32_t >();
49+ if (node[" n_prev" ] && node[ " n_prev " ]. IsScalar () ) sampling.n_prev = node[" n_prev" ].as <int32_t >();
50+ if (node[" n_probs" ] && node[ " n_probs " ]. IsScalar () ) sampling.n_probs = node[" n_probs" ].as <int32_t >();
51+ if (node[" min_keep" ] && node[ " min_keep " ]. IsScalar () ) sampling.min_keep = node[" min_keep" ].as <int32_t >();
52+ if (node[" top_k" ] && node[ " top_k " ]. IsScalar () ) sampling.top_k = node[" top_k" ].as <int32_t >();
53+ if (node[" top_p" ] && node[ " top_p " ]. IsScalar () ) sampling.top_p = node[" top_p" ].as <float >();
54+ if (node[" min_p" ] && node[ " min_p " ]. IsScalar () ) sampling.min_p = node[" min_p" ].as <float >();
55+ if (node[" xtc_probability" ] && node[ " xtc_probability " ]. IsScalar () ) sampling.xtc_probability = node[" xtc_probability" ].as <float >();
56+ if (node[" xtc_threshold" ] && node[ " xtc_threshold " ]. IsScalar () ) sampling.xtc_threshold = node[" xtc_threshold" ].as <float >();
57+ if (node[" typ_p" ] && node[ " typ_p " ]. IsScalar () ) sampling.typ_p = node[" typ_p" ].as <float >();
58+ if (node[" temp" ] && node[ " temp " ]. IsScalar () ) sampling.temp = node[" temp" ].as <float >();
59+ if (node[" dynatemp_range" ] && node[ " dynatemp_range " ]. IsScalar () ) sampling.dynatemp_range = node[" dynatemp_range" ].as <float >();
60+ if (node[" dynatemp_exponent" ] && node[ " dynatemp_exponent " ]. IsScalar () ) sampling.dynatemp_exponent = node[" dynatemp_exponent" ].as <float >();
61+ if (node[" penalty_last_n" ] && node[ " penalty_last_n " ]. IsScalar () ) sampling.penalty_last_n = node[" penalty_last_n" ].as <int32_t >();
62+ if (node[" penalty_repeat" ] && node[ " penalty_repeat " ]. IsScalar () ) sampling.penalty_repeat = node[" penalty_repeat" ].as <float >();
63+ if (node[" penalty_freq" ] && node[ " penalty_freq " ]. IsScalar () ) sampling.penalty_freq = node[" penalty_freq" ].as <float >();
64+ if (node[" penalty_present" ] && node[ " penalty_present " ]. IsScalar () ) sampling.penalty_present = node[" penalty_present" ].as <float >();
65+ if (node[" dry_multiplier" ] && node[ " dry_multiplier " ]. IsScalar () ) sampling.dry_multiplier = node[" dry_multiplier" ].as <float >();
66+ if (node[" dry_base" ] && node[ " dry_base " ]. IsScalar () ) sampling.dry_base = node[" dry_base" ].as <float >();
67+ if (node[" dry_allowed_length" ] && node[ " dry_allowed_length " ]. IsScalar () ) sampling.dry_allowed_length = node[" dry_allowed_length" ].as <int32_t >();
68+ if (node[" dry_penalty_last_n" ] && node[ " dry_penalty_last_n " ]. IsScalar () ) sampling.dry_penalty_last_n = node[" dry_penalty_last_n" ].as <int32_t >();
69+ if (node[" mirostat" ] && node[ " mirostat " ]. IsScalar () ) sampling.mirostat = node[" mirostat" ].as <int32_t >();
70+ if (node[" top_n_sigma" ] && node[ " top_n_sigma " ]. IsScalar () ) sampling.top_n_sigma = node[" top_n_sigma" ].as <float >();
71+ if (node[" mirostat_tau" ] && node[ " mirostat_tau " ]. IsScalar () ) sampling.mirostat_tau = node[" mirostat_tau" ].as <float >();
72+ if (node[" mirostat_eta" ] && node[ " mirostat_eta " ]. IsScalar () ) sampling.mirostat_eta = node[" mirostat_eta" ].as <float >();
73+ if (node[" ignore_eos" ] && node[ " ignore_eos " ]. IsScalar () ) sampling.ignore_eos = node[" ignore_eos" ].as <bool >();
74+ if (node[" no_perf" ] && node[ " no_perf " ]. IsScalar () ) sampling.no_perf = node[" no_perf" ].as <bool >();
75+ if (node[" timing_per_token" ] && node[ " timing_per_token " ]. IsScalar () ) sampling.timing_per_token = node[" timing_per_token" ].as <bool >();
76+ if (node[" grammar" ] && node[ " grammar " ]. IsScalar () ) sampling.grammar = node[" grammar" ].as <std::string>();
77+ if (node[" grammar_lazy" ] && node[ " grammar_lazy " ]. IsScalar () ) sampling.grammar_lazy = node[" grammar_lazy" ].as <bool >();
7878
7979 if (node[" dry_sequence_breakers" ] && node[" dry_sequence_breakers" ].IsSequence ()) {
8080 sampling.dry_sequence_breakers .clear ();
81- for (const auto & breaker : node[" dry_sequence_breakers" ]) {
82- sampling.dry_sequence_breakers .push_back (breaker.as <std::string>());
81+ const auto & breakers = node[" dry_sequence_breakers" ];
82+ sampling.dry_sequence_breakers .reserve (breakers.size ());
83+ for (const auto & breaker : breakers) {
84+ if (breaker && breaker.IsScalar ()) {
85+ sampling.dry_sequence_breakers .push_back (breaker.as <std::string>());
86+ }
8387 }
8488 }
8589}
8690
8791static void parse_yaml_model (const YAML::Node& node, common_params_model& model) {
88- if (node[" path" ]) model.path = node[" path" ].as <std::string>();
89- if (node[" url" ]) model.url = node[" url" ].as <std::string>();
90- if (node[" hf_repo" ]) model.hf_repo = node[" hf_repo" ].as <std::string>();
91- if (node[" hf_file" ]) model.hf_file = node[" hf_file" ].as <std::string>();
92+ if (node[" path" ] && node[ " path " ]. IsScalar () ) model.path = node[" path" ].as <std::string>();
93+ if (node[" url" ] && node[ " url " ]. IsScalar () ) model.url = node[" url" ].as <std::string>();
94+ if (node[" hf_repo" ] && node[ " hf_repo " ]. IsScalar () ) model.hf_repo = node[" hf_repo" ].as <std::string>();
95+ if (node[" hf_file" ] && node[ " hf_file " ]. IsScalar () ) model.hf_file = node[" hf_file" ].as <std::string>();
9296}
9397
9498static void parse_yaml_speculative (const YAML::Node& node, common_params_speculative& spec) {
95- if (node[" n_ctx" ]) spec.n_ctx = node[" n_ctx" ].as <int32_t >();
96- if (node[" n_max" ]) spec.n_max = node[" n_max" ].as <int32_t >();
97- if (node[" n_min" ]) spec.n_min = node[" n_min" ].as <int32_t >();
98- if (node[" n_gpu_layers" ]) spec.n_gpu_layers = node[" n_gpu_layers" ].as <int32_t >();
99- if (node[" p_split" ]) spec.p_split = node[" p_split" ].as <float >();
100- if (node[" p_min" ]) spec.p_min = node[" p_min" ].as <float >();
101- if (node[" cache_type_k" ]) {
99+ if (node[" n_ctx" ] && node[ " n_ctx " ]. IsScalar () ) spec.n_ctx = node[" n_ctx" ].as <int32_t >();
100+ if (node[" n_max" ] && node[ " n_max " ]. IsScalar () ) spec.n_max = node[" n_max" ].as <int32_t >();
101+ if (node[" n_min" ] && node[ " n_min " ]. IsScalar () ) spec.n_min = node[" n_min" ].as <int32_t >();
102+ if (node[" n_gpu_layers" ] && node[ " n_gpu_layers " ]. IsScalar () ) spec.n_gpu_layers = node[" n_gpu_layers" ].as <int32_t >();
103+ if (node[" p_split" ] && node[ " p_split " ]. IsScalar () ) spec.p_split = node[" p_split" ].as <float >();
104+ if (node[" p_min" ] && node[ " p_min " ]. IsScalar () ) spec.p_min = node[" p_min" ].as <float >();
105+ if (node[" cache_type_k" ] && node[ " cache_type_k " ]. IsScalar () ) {
102106 std::string cache_type = node[" cache_type_k" ].as <std::string>();
103107 if (cache_type == " f16" ) spec.cache_type_k = GGML_TYPE_F16;
104108 else if (cache_type == " f32" ) spec.cache_type_k = GGML_TYPE_F32;
@@ -108,7 +112,7 @@ static void parse_yaml_speculative(const YAML::Node& node, common_params_specula
108112 else if (cache_type == " q5_1" ) spec.cache_type_k = GGML_TYPE_Q5_1;
109113 else if (cache_type == " q8_0" ) spec.cache_type_k = GGML_TYPE_Q8_0;
110114 }
111- if (node[" cache_type_v" ]) {
115+ if (node[" cache_type_v" ] && node[ " cache_type_v " ]. IsScalar () ) {
112116 std::string cache_type = node[" cache_type_v" ].as <std::string>();
113117 if (cache_type == " f16" ) spec.cache_type_v = GGML_TYPE_F16;
114118 else if (cache_type == " f32" ) spec.cache_type_v = GGML_TYPE_F32;
@@ -118,28 +122,28 @@ static void parse_yaml_speculative(const YAML::Node& node, common_params_specula
118122 else if (cache_type == " q5_1" ) spec.cache_type_v = GGML_TYPE_Q5_1;
119123 else if (cache_type == " q8_0" ) spec.cache_type_v = GGML_TYPE_Q8_0;
120124 }
121- if (node[" model" ]) {
125+ if (node[" model" ] && node[ " model " ]. IsMap () ) {
122126 parse_yaml_model (node[" model" ], spec.model );
123127 }
124128}
125129
126130static void parse_yaml_vocoder (const YAML::Node& node, common_params_vocoder& vocoder) {
127- if (node[" speaker_file" ]) vocoder.speaker_file = node[" speaker_file" ].as <std::string>();
128- if (node[" use_guide_tokens" ]) vocoder.use_guide_tokens = node[" use_guide_tokens" ].as <bool >();
129- if (node[" model" ]) {
131+ if (node[" speaker_file" ] && node[ " speaker_file " ]. IsScalar () ) vocoder.speaker_file = node[" speaker_file" ].as <std::string>();
132+ if (node[" use_guide_tokens" ] && node[ " use_guide_tokens " ]. IsScalar () ) vocoder.use_guide_tokens = node[" use_guide_tokens" ].as <bool >();
133+ if (node[" model" ] && node[ " model " ]. IsMap () ) {
130134 parse_yaml_model (node[" model" ], vocoder.model );
131135 }
132136}
133137
134138static void parse_yaml_diffusion (const YAML::Node& node, common_params_diffusion& diffusion) {
135- if (node[" steps" ]) diffusion.steps = node[" steps" ].as <int32_t >();
136- if (node[" visual_mode" ]) diffusion.visual_mode = node[" visual_mode" ].as <bool >();
137- if (node[" eps" ]) diffusion.eps = node[" eps" ].as <float >();
138- if (node[" block_length" ]) diffusion.block_length = node[" block_length" ].as <int32_t >();
139- if (node[" algorithm" ]) diffusion.algorithm = node[" algorithm" ].as <int32_t >();
140- if (node[" alg_temp" ]) diffusion.alg_temp = node[" alg_temp" ].as <float >();
141- if (node[" cfg_scale" ]) diffusion.cfg_scale = node[" cfg_scale" ].as <float >();
142- if (node[" add_gumbel_noise" ]) diffusion.add_gumbel_noise = node[" add_gumbel_noise" ].as <bool >();
139+ if (node[" steps" ] && node[ " steps " ]. IsScalar () ) diffusion.steps = node[" steps" ].as <int32_t >();
140+ if (node[" visual_mode" ] && node[ " visual_mode " ]. IsScalar () ) diffusion.visual_mode = node[" visual_mode" ].as <bool >();
141+ if (node[" eps" ] && node[ " eps " ]. IsScalar () ) diffusion.eps = node[" eps" ].as <float >();
142+ if (node[" block_length" ] && node[ " block_length " ]. IsScalar () ) diffusion.block_length = node[" block_length" ].as <int32_t >();
143+ if (node[" algorithm" ] && node[ " algorithm " ]. IsScalar () ) diffusion.algorithm = node[" algorithm" ].as <int32_t >();
144+ if (node[" alg_temp" ] && node[ " alg_temp " ]. IsScalar () ) diffusion.alg_temp = node[" alg_temp" ].as <float >();
145+ if (node[" cfg_scale" ] && node[ " cfg_scale " ]. IsScalar () ) diffusion.cfg_scale = node[" cfg_scale" ].as <float >();
146+ if (node[" add_gumbel_noise" ] && node[ " add_gumbel_noise " ]. IsScalar () ) diffusion.add_gumbel_noise = node[" add_gumbel_noise" ].as <bool >();
143147}
144148
145149static bool load_yaml_config (const std::string& config_path, common_params& params) {
@@ -1504,20 +1508,40 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
15041508 const common_params params_org = ctx_arg.params ; // the example can modify the default params
15051509
15061510 try {
1511+ bool has_config = false ;
15071512 for (int i = 1 ; i < argc; i++) {
15081513 if (strcmp (argv[i], " --config" ) == 0 && i + 1 < argc) {
15091514 if (!load_yaml_config (argv[i + 1 ], ctx_arg.params )) {
15101515 fprintf (stderr, " Failed to load YAML config: %s\n " , argv[i + 1 ]);
15111516 ctx_arg.params = params_org;
15121517 return false ;
15131518 }
1514- break ;
1519+ has_config = true ;
1520+ break ; // Only process first --config for now
15151521 }
15161522 }
15171523
1518- if (!common_params_parse_ex (argc, argv, ctx_arg)) {
1519- ctx_arg.params = params_org;
1520- return false ;
1524+ if (has_config) {
1525+ std::vector<char *> filtered_argv;
1526+ filtered_argv.push_back (argv[0 ]); // Keep program name
1527+
1528+ for (int i = 1 ; i < argc; i++) {
1529+ if (strcmp (argv[i], " --config" ) == 0 && i + 1 < argc) {
1530+ i++; // Skip both --config and filename
1531+ } else {
1532+ filtered_argv.push_back (argv[i]);
1533+ }
1534+ }
1535+
1536+ if (!common_params_parse_ex (filtered_argv.size (), filtered_argv.data (), ctx_arg)) {
1537+ ctx_arg.params = params_org;
1538+ return false ;
1539+ }
1540+ } else {
1541+ if (!common_params_parse_ex (argc, argv, ctx_arg)) {
1542+ ctx_arg.params = params_org;
1543+ return false ;
1544+ }
15211545 }
15221546 if (ctx_arg.params .usage ) {
15231547 common_params_print_usage (ctx_arg);
0 commit comments