Commit 0fa3543

Merge branch 'ggml-org:master' into master
2 parents: 092fd84 + d32e03f

37 files changed: +942 additions, −255 deletions

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 56 additions & 1 deletion
@@ -1238,6 +1238,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
             common_params_print_completion(ctx_arg);
             exit(0);
         }
+        params.lr.init();
     } catch (const std::invalid_argument & ex) {
         fprintf(stderr, "%s\n", ex.what());
         ctx_arg.params = params_org;
@@ -1506,6 +1507,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--swa-checkpoints"}, "N",
+        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        [](common_params & params, int value) {
+            params.n_swa_checkpoints = value;
+        }
+    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
@@ -2688,7 +2697,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.out_file = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"-ofreq", "--output-frequency"}, "N",
         string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@@ -3566,5 +3575,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
 
 
+    add_opt(
+        common_arg({ "-lr", "--learning-rate" }, "ALPHA",
+                   string_format(
+                       "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
+                       (double) params.lr.lr0),
+                   [](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
+                   string_format(
+                       "(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
+                       (double) params.lr.lr_min),
+                   [](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
+                   string_format(
+                       "(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
+                       (double) params.lr.decay_epochs),
+                   [](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format(
+            "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+            (double) params.lr.wd),
+        [](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
+                       string_format("fraction of data to use as validation set for training (default: %.2g).",
+                                     (double) params.val_split),
+                       [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+                       string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
+                       [](common_params & params, int epochs) { params.lr.epochs = epochs; })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
+                       [](common_params & params, const std::string & name) {
+                           params.optimizer = common_opt_get_optimizer(name.c_str());
+                           if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                               throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+                           }
+                       })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
     return ctx_arg;
 }
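
A note on the `-opt` handler above: it validates input by mapping the name to an enum and treating the `GGML_OPT_OPTIMIZER_TYPE_COUNT` sentinel as "no match"; the thrown `std::invalid_argument` is what the `catch` in `common_params_parse` prints. A minimal, self-contained sketch of that sentinel pattern (the enum and lookup below are stand-ins, not the llama.cpp definitions):

// Sketch of the COUNT-sentinel validation used by the -opt handler above.
// The enum and lookup are stand-ins, not the llama.cpp definitions.
#include <cstdio>
#include <cstring>
#include <stdexcept>

enum opt_type { OPT_ADAMW, OPT_SGD, OPT_COUNT /* one past the last valid value */ };

static opt_type opt_from_name(const char * n) {
    if (strcmp(n, "adamw") == 0) { return OPT_ADAMW; }
    if (strcmp(n, "sgd")   == 0) { return OPT_SGD; }
    return OPT_COUNT; // sentinel: no match
}

int main() {
    try {
        const opt_type t = opt_from_name("adam"); // deliberate typo
        if (t == OPT_COUNT) {
            throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
        }
    } catch (const std::invalid_argument & ex) {
        fprintf(stderr, "%s\n", ex.what()); // mirrors the catch in common_params_parse
    }
    return 0;
}

Using the COUNT enumerator as the failure value avoids a separate "invalid" constant and stays correct if new optimizer types are appended before it.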

common/common.cpp

Lines changed: 54 additions & 0 deletions
@@ -41,6 +41,7 @@
 #endif
 #include <locale>
 #include <windows.h>
+#include <string.h>
 #include <fcntl.h>
 #include <io.h>
 #else
@@ -1565,3 +1566,56 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
     return result;
 }
+
+ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+// TODO make all command line args case-insensitive
+static inline bool eq_case_insensitive(char const * a, char const * b) {
+    return !
+#if defined(_MSC_VER)
+        _stricmp
+#else
+        strcasecmp
+#endif // defined(_MSC_VER)
+        (a, b);
+}
+
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+    if (eq_case_insensitive("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    }
+    if (eq_case_insensitive("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    }
+    return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+}
+
+// TODO simplify to use just log and exp
+static float const k_log_2 = std::log(2.f);
+
+void lr_opt::init() {
+    if (lr_min > 0 && lr_min < lr0) {
+        float nhalf = std::log(lr0 / lr_min) / k_log_2;
+        float e     = epochs;
+        if (decay_epochs > 0 && decay_epochs < e) {
+            e = decay_epochs;
+        } else {
+            decay_epochs = e;
+        }
+        scale_epoch = nhalf / e;
+    }
+}
+
+float lr_opt::get_lr(float epoch) const {
+    float r = lr_min <= 0           ? lr0 :
+              epoch >= decay_epochs ? lr_min :
+                                      lr0 * std::pow(0.5f, epoch * scale_epoch);
+    LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
+    return r;
+}
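
`lr_opt::init` converts the requested endpoints into a half-life schedule: `nhalf = log2(lr0 / lr_min)` halvings are spread over `decay_epochs`, so `get_lr` computes `lr0 * 0.5^(epoch * scale_epoch)`, reaching exactly `lr_min` at `epoch == decay_epochs` and clamping there afterwards. A self-contained sketch of that arithmetic with hypothetical values (lr0 = 1e-4, lr_min = 1e-6, decay over 10 epochs; only the formulas mirror the code above):

// Standalone check of the half-life decay arithmetic; the values are
// hypothetical, only the formulas mirror lr_opt::init and lr_opt::get_lr.
#include <cmath>
#include <cstdio>

int main() {
    const float lr0          = 1e-4f;
    const float lr_min       = 1e-6f;
    const float decay_epochs = 10.0f;

    // init(): nhalf = how many halvings take lr0 down to lr_min, spread over decay_epochs
    const float nhalf       = std::log(lr0 / lr_min) / std::log(2.0f); // log2(100) ~= 6.64
    const float scale_epoch = nhalf / decay_epochs;

    // get_lr(): exponential decay, clamped to lr_min once decay_epochs is reached
    for (int e = 0; e <= 12; ++e) {
        const float lr = e >= decay_epochs ? lr_min : lr0 * std::pow(0.5f, e * scale_epoch);
        printf("epoch %2d lr=%.3g\n", e, (double) lr); // epoch 10 and later print 1e-06
    }
    return 0;
}

With these numbers nhalf = log2(100) ≈ 6.64, so the rate halves roughly every 1.5 epochs and lands exactly on lr_min at epoch 10.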

common/common.h

Lines changed: 39 additions & 7 deletions
@@ -2,14 +2,17 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
 #include <map>
 #include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -82,6 +85,7 @@ enum llama_example {
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -243,6 +247,25 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+
+struct lr_opt {
+    float    lr0          = 1e-5;  // learning rate at first epoch
+    float    lr_min       = -1;
+    float    decay_epochs = -1;    // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
+    float    scale_epoch  = 0;
+    float    wd           = 0;
+    unsigned epochs       = 2;
+
+    unsigned epoch; // set by optimizer outer (epochs) loop
+    // learning rate decay - constant LR per epoch only for now
+    float get_lr(float e) const;
+    float get_lr() const { return get_lr(epoch); }
+    // must call after arg parse, before get_lr
+    void init();
+};
+
+struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
+
 struct common_params {
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 4096; // context size
@@ -377,6 +400,11 @@ struct common_params {
     bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct lr_opt lr;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of the data used for the validation set
+
     // embedding
     bool embedding = false; // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -385,11 +413,12 @@
     std::string cls_sep = "\t"; // separator of classification sequences
 
     // server params
-    int32_t port           = 8080;         // server listens on this network port
-    int32_t timeout_read   = 600;          // http read timeout in seconds
-    int32_t timeout_write  = timeout_read; // http write timeout in seconds
-    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
-    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t port              = 8080;         // server listens on this network port
+    int32_t timeout_read      = 600;          // http read timeout in seconds
+    int32_t timeout_write     = timeout_read; // http write timeout in seconds
+    int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t n_swa_checkpoints = 3;            // max number of SWA checkpoints per slot
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = ""; // NOLINT
@@ -704,3 +733,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //
 
 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
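
The `common_opt_lr_pars(void * userdata)` declaration above follows the usual ggml callback convention: the optimizer stores a plain function pointer plus an opaque `void *`, and the callback casts that pointer back to its typed state on every call. A minimal sketch of the convention with stand-in types (nothing below is a ggml or llama.cpp definition):

// Sketch of the void*-userdata callback convention used by common_opt_lr_pars.
// Every name below is a stand-in; it is not the ggml-opt API.
#include <cstdio>

struct hyper_params { float alpha; float wd; }; // stand-in for ggml_opt_optimizer_params
struct lr_state     { float lr0;   float wd; }; // stand-in for lr_opt

// Same shape as common_opt_lr_pars: recover typed state from the opaque pointer.
static hyper_params get_pars(void * userdata) {
    const lr_state & s = *(lr_state *) userdata;
    return { s.lr0, s.wd };
}

int main() {
    lr_state state = { 1e-4f, 1e-9f };

    // What the caller registers: a function pointer plus the userdata to pass it.
    hyper_params (*get_opt_pars)(void *) = get_pars;
    void * get_opt_pars_ud = &state;

    const hyper_params p = get_opt_pars(get_opt_pars_ud); // one "optimizer step"
    printf("alpha=%g wd=%g\n", (double) p.alpha, (double) p.wd);
    return 0;
}

This is why `lr_opt::epoch` can steer the schedule: the training loop mutates the struct that the registered pointer refers to, and the next callback invocation sees the new value.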

examples/training/finetune.cpp

Lines changed: 26 additions & 26 deletions
@@ -10,20 +10,20 @@
 #include <vector>
 
 #if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
+#    pragma warning(disable: 4244 4267)  // possible loss of data
 #endif
 
 int main(int argc, char ** argv) {
     common_params params;
-
     params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }
 
     if (params.use_mmap) {
-        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
+                __func__);
         params.use_mmap = false;
     }
     if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +38,10 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
-    common_init_result llama_init = common_init_from_params(params);
-    llama_model_ptr & model = llama_init.model;
-    llama_context_ptr & ctx = llama_init.context;
+    common_init_result llama_init = common_init_from_params(params);
+    llama_model_ptr &   model     = llama_init.model;
+    llama_context_ptr & ctx       = llama_init.context;
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +54,32 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    constexpr float val_split = 0.05f;
-
-    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
-
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
-
-    struct llama_opt_params lopt_params {
-        /*n_ctx_train =*/ 0,
-        /*param_filter =*/ llama_opt_param_filter_all,
-        /*param_filter_ud =*/ nullptr,
-        /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
-        /*get_opt_pars_ud =*/ &optimizer_params,
+    std::vector<llama_token> tokens  = common_tokenize(ctx.get(), params.prompt, true);
+    ggml_opt_dataset_t       dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
+
+    struct lr_opt & lr = params.lr;
+    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
+            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.decay_epochs,
+            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
+
+    struct llama_opt_params lopt_params{
+        /*n_ctx_train     =*/ 0,
+        /*param_filter    =*/ llama_opt_param_filter_all,
+        /*param_filter_ud =*/ nullptr,
+        /*get_opt_pars    =*/ common_opt_lr_pars,
+        /*get_opt_pars_ud =*/ &params.lr,
+        /*optimizer_type  =*/ params.optimizer,
     };
     llama_opt_init(ctx.get(), model.get(), lopt_params);
 
-    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval  = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
         llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
-            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
 
         ggml_opt_result_reset(result_train);
@@ -88,7 +88,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_eval);
 
-    llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+    llama_model_save_to_file(model.get(), params.out_file.c_str());
 
     llama_backend_free();

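One detail worth noting in finetune.cpp: the validation set is taken by index, not shuffled out. `idata_split` marks the boundary, and `llama_opt_epoch` trains on batches below it while only evaluating the rest. A quick self-contained sketch of the arithmetic with hypothetical sizes (1000 dataset items and the default val_split of 0.05):

// Sketch of the index-based train/validation split; the sizes are hypothetical.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_data    = 1000;   // stand-in for ggml_opt_dataset_ndata(dataset)
    const float   val_split = 0.05f;  // the default in common_params

    // Items [0, idata_split) are trained on; [idata_split, n_data) are evaluation-only.
    const int64_t idata_split = n_data * (1.0f - val_split);
    printf("train on %lld items, validate on %lld\n",
           (long long) idata_split, (long long) (n_data - idata_split));
    return 0;
}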