
Commit 5cdb27e

graehl authored, with co-authors 0cc4m and JohannesGaessler

finetune: SGD optimizer, more CLI args (#13873)

* examples/finetune -opt SGD (stochastic gradient descent) memory opt

  Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - it avoids allocating the m, v momentum tensors.
  Support the finetune.cpp arg -opt SGD (or sgd); the default remains adamw as before.

  llama 3.2-1b-F32 result: observed 11 GB GPU RAM (41 sec/epoch) when using SGD instead of 19 GB (55 sec/epoch) using adamw (finetuning on 100 lines of Wikipedia).

  Using the same GPU memory, adamw can only fit 512 batch/context before OOM, reaching:

    train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
    val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

  SGD is superior here, though it converges more slowly: it fits up to 1728 batch/context before OOM (note especially the better validation performance):

    train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
    val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00

  Note: when finetuning long enough (or with a large enough -lr), validation accuracy *eventually* drops ('catastrophic forgetting').

  The -lr-half (halflife) option is useful for SGD to avoid oscillation or very slow underdamped learning (it makes setting -lr more forgiving). The terminal -lr is for now set by lr-halvings, i.e. if you want at most 1/8 the initial -lr you set -lr-halvings 3.

  Note: objective loss may not be directly comparable between adamw and sgd - check perplexity or accuracy, or look at relative improvements, when judging convergence.

  New finetune args: -wd 1e-9 to enable weight decay in sgd or adamw, and a maximum -epochs N (default 2 as before).

  Caching (1 - wd*alpha) in the 'adamw' opt struct showed no noticeable perf benefit and is disabled (it is still done for the new SGD path).

  Since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could probably switch between SGD and AdamW with each epoch, but it would need to use adamw for the first one (unconfirmed - there is no cmdline arg to set such a policy yet).

  test-opt checks adamw as before and now also sgd (except for a few tests disabled for sgd only; these probably just need the values logged and alternate reference values added); the tolerance on the 'regression' test is broader for sgd (so we don't need many more epochs).

* Vulkan: Implement GGML_OP_OPT_STEP_SGD

* tests: Fix OPT_STEP_SGD test-backend-ops

* SGD op param store weight-decay and not 1-alpha*wd

* minor + cosmetic changes

* fix vulkan sgd

* try CI fix

---------

Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 3ea913f commit 5cdb27e

24 files changed: +727 -196 lines
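
For reference, a hypothetical invocation exercising the new options (the binary name, model file, and training-text file are placeholders; the flags are the ones added in common/arg.cpp below):

    llama-finetune -m Llama-3.2-1B-F32.gguf -f train.txt \
        -opt sgd -lr 1e-4 -lr-min 1e-5 -decay-epochs 2 -wd 1e-9 -epochs 2 -val-split 0.05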

CMakeLists.txt
Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp
Lines changed: 48 additions & 1 deletion

@@ -1238,6 +1238,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
            common_params_print_completion(ctx_arg);
            exit(0);
        }
+        params.lr.init();
    } catch (const std::invalid_argument & ex) {
        fprintf(stderr, "%s\n", ex.what());
        ctx_arg.params = params_org;
@@ -2688,7 +2689,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params, const std::string & value) {
            params.out_file = value;
        }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
    add_opt(common_arg(
        {"-ofreq", "--output-frequency"}, "N",
        string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@@ -3566,5 +3567,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
 
 
+    add_opt(
+        common_arg({ "-lr", "--learning-rate" }, "ALPHA",
+                   string_format(
+                       "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
+                       (double) params.lr.lr0),
+                   [](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
+                   string_format(
+                       "(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
+                       (double) params.lr.lr_min),
+                   [](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
+                   string_format(
+                       "(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
+                       (double) params.lr.decay_epochs),
+                   [](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format(
+            "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+            (double) params.lr.wd),
+        [](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
+                       string_format("fraction of data to use as validation set for training (default: %.2g).",
+                                     (double) params.val_split),
+                       [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+                       string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
+                       [](common_params & params, int epochs) { params.lr.epochs = epochs; })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
+                       [](common_params & params, const std::string & name) {
+                           params.optimizer = common_opt_get_optimizer(name.c_str());
+                           if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                               throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+                           }
+                       })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
    return ctx_arg;
 }

common/common.cpp
Lines changed: 54 additions & 0 deletions

@@ -41,6 +41,7 @@
 #endif
 #include <locale>
 #include <windows.h>
+#include <string.h>
 #include <fcntl.h>
 #include <io.h>
 #else
@@ -1565,3 +1566,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
    return result;
 }
+
+ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+// TODO make all command line args case-insensitive
+static inline bool eq_case_insensitive(char const* a, char const* b) {
+    return !
+#if defined(_MSC_VER)
+        _stricmp
+#else
+        strcasecmp
+#endif // defined(_MSC_VER)
+        (a, b);
+}
+
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+    if (eq_case_insensitive("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    }
+    if (eq_case_insensitive("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    }
+    return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+}
+
+// TODO simplify to use just log and exp
+static float const k_log_2 = std::log(2.f);
+
+void lr_opt::init() {
+    if (lr_min > 0 && lr_min < lr0) {
+        float nhalf = std::log(lr0 / lr_min) / k_log_2;
+        float e = epochs;
+        if (decay_epochs > 0 && decay_epochs < e) {
+            e = decay_epochs;
+        } else {
+            decay_epochs = e;
+        }
+        scale_epoch = nhalf / e;
+    }
+}
+
+float lr_opt::get_lr(float epoch) const {
+    float r = lr_min <= 0 ? lr0 :
+              epoch >= decay_epochs ? lr_min :
+              lr0 * std::pow(0.5f, epoch * scale_epoch);
+    LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
+    return r;
+}
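
When -lr-min is set (0 < lr_min < lr0), the schedule above halves the learning rate a fixed number of times spread over decay_epochs: init() computes nhalf = log2(lr0 / lr_min) and scale_epoch = nhalf / decay_epochs, and get_lr(e) returns lr0 * 0.5^(e * scale_epoch), holding at lr_min once e reaches decay_epochs; otherwise the rate stays constant at lr0. A standalone sketch of the same arithmetic (illustrative values, not the library code):

    #include <cmath>
    #include <cstdio>

    int main() {
        // same formula as lr_opt::init()/get_lr(), with example hyperparameters
        const float lr0 = 1e-4f, lr_min = 1e-5f, decay_epochs = 4.0f;
        const float nhalf       = std::log2(lr0 / lr_min);  // ~3.32 halvings in total
        const float scale_epoch = nhalf / decay_epochs;     // halvings per epoch
        for (int e = 0; e <= 4; ++e) {
            const float lr = e >= decay_epochs ? lr_min
                                               : lr0 * std::pow(0.5f, e * scale_epoch);
            std::printf("epoch %d lr=%.3g\n", e, lr);        // decays from 1e-4 down to 1e-5
        }
        return 0;
    }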

common/common.h
Lines changed: 33 additions & 2 deletions

@@ -2,14 +2,17 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
 #include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -82,6 +85,7 @@ enum llama_example {
    LLAMA_EXAMPLE_PARALLEL,
    LLAMA_EXAMPLE_TTS,
    LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_FINETUNE,
 
    LLAMA_EXAMPLE_COUNT,
 };
@@ -243,6 +247,25 @@ enum common_reasoning_format {
    COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+
+struct lr_opt {
+    float lr0 = 1e-5;        // learning rate at first epoch
+    float lr_min = -1;
+    float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
+    float scale_epoch = 0;
+    float wd = 0;
+    unsigned epochs = 2;
+
+    unsigned epoch; // set by optimizer outer (epochs) loop
+    // learning rate decay - constant LR per epoch only for now
+    float get_lr(float e) const;
+    float get_lr() const { return get_lr(epoch); }
+    // must call after arg parse, before get_lr
+    void init();
+};
+
+struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
+
 struct common_params {
    int32_t n_predict = -1; // new tokens to predict
    int32_t n_ctx = 4096;   // context size
@@ -377,6 +400,11 @@ struct common_params {
    bool no_mmproj = false;         // explicitly disable multimodal model
    std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct lr_opt lr;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of the data used for the validation set
+
    // embedding
    bool embedding = false;  // get only sentence embedding
    int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -704,3 +732,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //
 
 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
examples/training/finetune.cpp
Lines changed: 26 additions & 26 deletions

@@ -10,20 +10,20 @@
 #include <vector>
 
 #if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
+#    pragma warning(disable: 4244 4267)  // possible loss of data
 #endif
 
 int main(int argc, char ** argv) {
    common_params params;
-
    params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }
 
    if (params.use_mmap) {
-        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
+                __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +38,10 @@ int main(int argc, char ** argv) {
    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);
-
    // load the model and apply lora adapter, if any
-    common_init_result llama_init = common_init_from_params(params);
-    llama_model_ptr & model = llama_init.model;
-    llama_context_ptr & ctx = llama_init.context;
+    common_init_result llama_init = common_init_from_params(params);
+    llama_model_ptr &   model     = llama_init.model;
+    llama_context_ptr & ctx       = llama_init.context;
 
    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +54,32 @@ int main(int argc, char ** argv) {
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }
 
-    constexpr float val_split = 0.05f;
-
-    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
-
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
-
-    struct llama_opt_params lopt_params {
-        /*n_ctx_train =*/ 0,
-        /*param_filter =*/ llama_opt_param_filter_all,
-        /*param_filter_ud =*/ nullptr,
-        /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
-        /*get_opt_pars_ud =*/ &optimizer_params,
+    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
+
+    struct lr_opt & lr = params.lr;
+    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
+            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.decay_epochs,
+            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
+
+    struct llama_opt_params lopt_params{
+        /*n_ctx_train =*/0,
+        /*param_filter =*/llama_opt_param_filter_all,
+        /*param_filter_ud =*/nullptr,
+        /*get_opt_pars =*/common_opt_lr_pars,
+        /*get_opt_pars_ud =*/&params.lr,
+        /*optimizer_type =*/params.optimizer,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);
 
-    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
-            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");
 
        ggml_opt_result_reset(result_train);
@@ -88,7 +88,7 @@ int main(int argc, char ** argv) {
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
 
-    llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+    llama_model_save_to_file(model.get(), params.out_file.c_str());
 
    llama_backend_free();

ggml/include/ggml-opt.h
Lines changed: 25 additions & 6 deletions

@@ -74,16 +74,26 @@ extern "C" {
        GGML_OPT_BUILD_TYPE_OPT = 30,
    };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
    // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
    struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
        struct {
            float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
            float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay - 0.0f to disable
        } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
    };
 
    // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
        int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        ggml_opt_get_optimizer_params get_opt_pars;  // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                      // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
    };
 
    // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
    // get the gradient accumulator for a node from the forward graph
    GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
    // ====== Optimization Result ======
 
    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
            struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
            ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
            enum ggml_opt_loss_type loss_type, // loss to minimize
+            enum ggml_opt_optimizer_type optimizer, // sgd or adamw
            ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
            int64_t nepoch, // how many times the dataset should be iterated over
            int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
            float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
            bool silent); // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
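
With the extended struct above, a get_opt_pars callback can fill both sub-structs unconditionally (as common_opt_lr_pars does); only the one matching the optimizer selected at init time is read. A minimal sketch of such a callback (the constant values are arbitrary examples, not recommendations):

    // hypothetical constant-parameter callback matching ggml_opt_get_optimizer_params
    static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
        (void) userdata;  // unused here; normally carries schedule state (see lr_opt)
        struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
        p.adamw.alpha = 1e-5f;  // read when the context uses GGML_OPT_OPTIMIZER_TYPE_ADAMW
        p.sgd.alpha   = 1e-4f;  // read when it uses GGML_OPT_OPTIMIZER_TYPE_SGD
        p.adamw.wd    = 1e-9f;  // weight decay; 0.0f disables it
        p.sgd.wd      = 1e-9f;
        return p;
    }

Selecting GGML_OPT_OPTIMIZER_TYPE_SGD also skips allocating the per-parameter m, v momentum tensors (per the comment on the optimizer field), which is where the memory savings quoted in the commit message come from.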
