
Commit 5a22882

finetune.cpp command-line arg
Add a learning-rate (AdamW alpha) command-line arg and an optimizer enum (defaulting to AdamW) to ggml-opt, in preparation for SGD support. These live in common args as a set of optimizer options active only for the new FINETUNE example, which also keeps all the PERPLEXITY options that finetune.cpp previously used, as a precaution. Perhaps breaking with precedent, the ggml_opt_optimizer_params struct is included directly in the args; if desired, we can instead add just the learning rate and optimizer type to a struct independent of ggml-opt.h, as proposed in #13835.
1 parent e0e3aa2 commit 5a22882

5 files changed: +51, -15 lines


common/arg.cpp

Lines changed: 34 additions & 12 deletions
@@ -1095,6 +1095,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-embedding",
         "llama-eval-callback",
         "llama-export-lora",
+        "llama-finetune",
         "llama-gen-docs",
         "llama-gguf",
         "llama-gguf-hash",
@@ -1239,6 +1240,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     sampler_type_names.pop_back();


+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
+    params.optimize.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+
     /**
      * filter options by example
      * rules:
@@ -1472,14 +1476,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params) {
            params.ctx_shift = false;
        }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
     add_opt(common_arg(
        {"--chunks"}, "N",
        string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
        [](common_params & params, int value) {
            params.n_chunks = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(common_arg(
        {"-fa", "--flash-attn"},
        string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -2117,70 +2121,88 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params) {
            params.hellaswag = true;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--hellaswag-tasks"}, "N",
        string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks),
        [](common_params & params, int value) {
            params.hellaswag_tasks = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--winogrande"},
        "compute Winogrande score over random tasks from datafile supplied with -f",
        [](common_params & params) {
            params.winogrande = true;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--winogrande-tasks"}, "N",
        string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks),
        [](common_params & params, int value) {
            params.winogrande_tasks = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--multiple-choice"},
        "compute multiple choice score over random tasks from datafile supplied with -f",
        [](common_params & params) {
            params.multiple_choice = true;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--multiple-choice-tasks"}, "N",
        string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks),
        [](common_params & params, int value) {
            params.multiple_choice_tasks = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--kl-divergence"},
        "computes KL-divergence to logits provided via --kl-divergence-base",
        [](common_params & params) {
            params.kl_divergence = true;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--save-all-logits", "--kl-divergence-base"}, "FNAME",
        "set logits file",
        [](common_params & params, const std::string & value) {
            params.logits_file = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--ppl-stride"}, "N",
        string_format("stride for perplexity calculation (default: %d)", params.ppl_stride),
        [](common_params & params, int value) {
            params.ppl_stride = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"--ppl-output-type"}, "<0|1>",
        string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type),
        [](common_params & params, int value) {
            params.ppl_output_type = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
+    add_opt(common_arg(
+       {"-lr", "-alpha", "--alpha", "--learning-rate"}, "ALPHA",
+       string_format("adamw optimizer alpha (default: %.1f)", (double)params.optimize.adamw.alpha),
+       [](common_params & params, const std::string & value) {
+           params.optimize.adamw.alpha = std::stof(value);
+       }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
+    add_opt(common_arg(
+       {"-opt", "--optimizer"}, "N",
+       "adamw (N=0) or //TODO:SGD (N=1)",
+       [](common_params & params, int N) {
+           if (N == GGML_OPT_OPTIMIZER_SGD)
+               throw std::invalid_argument("TODO: implement SGD");
+           if (N >= GGML_OPT_OPTIMIZER_COUNT)
+               throw std::invalid_argument("invalid --optimizer N (try 0)");
+           params.optimize.optimizer = (enum ggml_opt_optimizer)N;
+       }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
        {"-dt", "--defrag-thold"}, "N",
        string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),

common/common.h

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 #pragma once

 #include "llama-cpp.h"
+#include "ggml-opt.h"

 #include <set>
 #include <string>
@@ -80,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_FINETUNE,

     LLAMA_EXAMPLE_COUNT,
 };
@@ -349,6 +351,8 @@ struct common_params {
     bool no_mmproj = false;          // explicitly disable multimodal model
     std::vector<std::string> image;  // path to image file(s)

+    // finetune
+    struct ggml_opt_optimizer_params optimize;
     // embedding
     bool embedding = false;     // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)

examples/training/finetune.cpp

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@ int main(int argc, char ** argv) {

     params.escape = false;

-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }

@@ -60,8 +60,8 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
+    struct ggml_opt_optimizer_params &optimizer_params = params.optimize;
+    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double)optimizer_params.adamw.alpha);

     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
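
With the reference now bound to params.optimize, the values chosen via -lr/-opt are what the training loop's parameter callback returns. The hunk cuts off after n_ctx_train, so the remaining initializer fields below are an assumption based on the existing training example (llama_opt_params in llama.h and ggml_opt_get_constant_optimizer_params in ggml-opt.h), shown only to illustrate the hand-off:

    // sketch: forwarding the parsed optimizer params through the constant-params callback;
    // ggml_opt_get_constant_optimizer_params returns the struct its userdata points to
    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_all,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
        /*get_opt_pars_ud =*/ &optimizer_params, // i.e. &params.optimize after this change
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);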

ggml/include/ggml-opt.h

Lines changed: 9 additions & 0 deletions
@@ -74,6 +74,14 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };

+    enum ggml_opt_optimizer
+    {
+        GGML_OPT_OPTIMIZER_ADAMW,
+        GGML_OPT_OPTIMIZER_SGD,
+
+        GGML_OPT_OPTIMIZER_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
         // AdamW optimizer parameters
@@ -84,6 +92,7 @@ extern "C" {
             float eps;   // epsilon for numerical stability
             float wd;    // weight decay for AdamW, use 0.0f to disable
         } adamw;
+        enum ggml_opt_optimizer optimizer;
     };

     // callback to calculate optimizer parameters prior to a backward pass
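
For callers outside the finetune example, the new field travels through the optimizer-parameter callback referenced in the context line above. A minimal sketch, assuming only the declarations in this header; the function name and userdata layout are illustrative, not part of this commit:

    // sketch: a parameter callback that fills in the new optimizer field
    #include "ggml-opt.h"

    static struct ggml_opt_optimizer_params finetune_opt_pars(void * userdata) {
        struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
        p.adamw.alpha = *(const float *) userdata; // learning rate supplied by the caller
        p.optimizer   = GGML_OPT_OPTIMIZER_ADAMW;  // SGD is declared here but still TODO in this commit
        return p;
    }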

ggml/src/ggml-opt.cpp

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.beta2 = 0.999f;
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;
+    result.optimizer   = GGML_OPT_OPTIMIZER_ADAMW;

     return result;
 }
