@@ -1095,6 +1095,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-embedding",
         "llama-eval-callback",
         "llama-export-lora",
+        "llama-finetune",
         "llama-gen-docs",
         "llama-gguf",
         "llama-gguf-hash",
@@ -1239,6 +1240,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     sampler_type_names.pop_back();


+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
+    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+
     /**
      * filter options by example
      * rules:
@@ -1472,14 +1476,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
            params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
         [](common_params & params, int value) {
             params.n_chunks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(common_arg(
         {"-fa", "--flash-attn"},
         string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -2117,70 +2121,88 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.hellaswag = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--hellaswag-tasks"}, "N",
         string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks),
         [](common_params & params, int value) {
             params.hellaswag_tasks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--winogrande"},
         "compute Winogrande score over random tasks from datafile supplied with -f",
         [](common_params & params) {
             params.winogrande = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--winogrande-tasks"}, "N",
         string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks),
         [](common_params & params, int value) {
             params.winogrande_tasks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--multiple-choice"},
         "compute multiple choice score over random tasks from datafile supplied with -f",
         [](common_params & params) {
             params.multiple_choice = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--multiple-choice-tasks"}, "N",
         string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks),
         [](common_params & params, int value) {
             params.multiple_choice_tasks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--kl-divergence"},
         "computes KL-divergence to logits provided via --kl-divergence-base",
         [](common_params & params) {
             params.kl_divergence = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--save-all-logits", "--kl-divergence-base"}, "FNAME",
         "set logits file",
         [](common_params & params, const std::string & value) {
             params.logits_file = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--ppl-stride"}, "N",
         string_format("stride for perplexity calculation (default: %d)", params.ppl_stride),
         [](common_params & params, int value) {
             params.ppl_stride = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--ppl-output-type"}, "<0|1>",
         string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type),
         [](common_params & params, int value) {
             params.ppl_output_type = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
+    add_opt(common_arg(
+        {"-lr", "-alpha", "--alpha", "--learning-rate"}, "ALPHA",
+        string_format("adamw optimizer alpha (default: %g)", (double)params.optimize.adamw.alpha),
+        [](common_params & params, const std::string & value) {
+            params.optimize.adamw.alpha = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
+    add_opt(common_arg(
+        {"-opt", "--optimizer"}, "N",
+        "adamw (N=0) or //TODO:SGD (N=1)",
+        [](common_params & params, int N) {
+            if (N == GGML_OPT_OPTIMIZER_SGD)
+                throw std::invalid_argument("TODO: implement SGD");
+            if (N >= GGML_OPT_OPTIMIZER_COUNT)
+                throw std::invalid_argument("invalid --optimizer N (try 0)");
+            params.optimize.optimizer = (enum ggml_opt_optimizer)N;
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"-dt", "--defrag-thold"}, "N",
         string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
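
Note: the new -lr/--learning-rate and -opt/--optimizer options assume an optimizer-parameters struct exposed by ggml. Below is a minimal sketch of the shape implied by the fields this patch touches; the enum member name GGML_OPT_OPTIMIZER_ADAMW and the exact field list are assumptions, and the authoritative definitions live in ggml-opt.h on this branch.

// Sketch only: inferred from the lines touched above, not copied from ggml-opt.h.
enum ggml_opt_optimizer {
    GGML_OPT_OPTIMIZER_ADAMW = 0, // assumed name; selected by "-opt 0" (default)
    GGML_OPT_OPTIMIZER_SGD   = 1, // "-opt 1" is rejected above with "TODO: implement SGD"
    GGML_OPT_OPTIMIZER_COUNT,
};

struct ggml_opt_optimizer_params {
    enum ggml_opt_optimizer optimizer; // set via -opt / --optimizer
    struct {
        float alpha; // learning rate; ggml default 1e-3, lowered to 1e-8 above, overridable via -lr
        // beta1, beta2, eps, wd, ... (untouched by this patch)
    } adamw;
};

// Defaults come from ggml and are then adjusted for LLAMA_EXAMPLE_FINETUNE:
//   params.optimize = ggml_opt_get_default_optimizer_params(NULL);
//   params.optimize.adamw.alpha = 1e-8;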