
Commit 8a1eb16

examples/finetune -opt SGD (stochastic gradient descent) memory opt
support finetune arg -opt SGD (or sgd). llama 3.2-1b-F32 result: observed 11gb gpu ram (45 sec/epoch) when using SGD instead of 19gb (55 sec/epoch) using adamw. (getting the right learning rate for SGD is trickier than for adamw - too high and you overshoot+oscillate, too low and you waste compute slowly approaching convergence) SGD (or adamw) quickly reach 99%+ train accuracy. note: objective loss not directly comparable between adamw, sgd? - check perplexity or accuracy or consider relative improvements for convergence also, note that logical batch size > physical batch (gradient accumulation) seems unsupported for optimization (limited to physical , unlike in ppx - also limited to ctx-size). training quality/convergence could be improved by implementing (at cost of some memory, but you can make that up by using a much smaller physical batch for a net memory savings). presumably it's physical batch that should be limited to ctx-size? see llama_context::opt_epoch new finetune args -wd 1e-9 to enable weight decay in sgd or adamw, and max -epochs N (default 2 as before) cache (1 - wd*alpha) in 'adamw' opt struct - no noticeable perf benefit cache computed per-epoch optimizer opts (formerly were computed twice per) add unit tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating m, v tensors. make ggml_opt_init aware of the optimization method since opt. memory is pre-allocated, the ggml_opt_get_optimizer_params would probably be able to change between SGD and AdamW with each epoch but would need to use adamw for the first (unconfirmed - no arg to set such a policy yet) 100 lines of wikipedia train: train: ... loss=0.00231±0.00032 acc=99.99±0.01% t=00:00:05 val: ... loss=3.91926±nan acc=58.40±2.18% on more training data (500 lines), additional catastrophic forgetting before train reaches 99.9% accuracy: train: data=0000140/0000140 loss=0.02611±0.00077 acc=99.82±0.02% t=00:00:45 val: data=0000008/0000008 loss=4.11112±0.22526 acc=46.36±0.78% increasing batch+ctx sizes to 1536 (double what fits in memory for adamw) gets apparently better validation but that could be an artifact of continuing training from previous weights, i.e. what's train vs val probably depends on batch size. also amusing - faster due to larger batch even though larger context would be slower?: train: data=0000045/0000045 loss=0.01722±0.00103 acc=99.90±0.01% t=00:00:40 val: data=0000003/0000003 loss=1.96829±1.09488 acc=72.44±0.66%
1 parent 5787b5d commit 8a1eb16

21 files changed, +545 / -159 lines


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 37 additions & 0 deletions
@@ -1181,6 +1181,7 @@ static void add_rpc_devices(std::string servers) {
 }
 
 bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
+    params.optimize = ggml_opt_get_default_optimizer_params(nullptr);
     auto ctx_arg = common_params_parser_init(params, ex, print_usage);
     const common_params params_org = ctx_arg.params; // the example can modify the default params
 
@@ -3376,5 +3377,41 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+    add_opt(
+        common_arg(
+            { "-lr", "--learning-rate" }, "ALPHA",
+            string_format(
+                "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~100x (no momentum)",
+                (double) params.optimize.adamw.alpha),
+            [](common_params & params, const std::string & value) { params.optimize.adamw.alpha = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-wd", "--weight-decay" }, "WD",
+        string_format(
+            "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+            (double) params.optimize.adamw.wd),
+        [](common_params & params, const std::string & value) { params.optimize.adamw.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+        { "-val", "--val-split" }, "FRACTION",
+        string_format(
+            "portion of data to use as validation when optimizing (default: %.2g).",
+            (double) params.val_split),
+        [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+        string_format("optimizer max # of epochs (default: %d)", params.epochs),
+        [](common_params & params, int epochs) { params.epochs = epochs; })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" },
+        "sgd|adamw",
+        "adamw or sgd",
+        [](common_params & params, const std::string & name) {
+            params.optimizer = ggml_opt_get_optimizer(name.c_str());
+            if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+            }
+        }).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
     return ctx_arg;
 }
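With these flags in place, a hypothetical finetune invocation could look like: llama-finetune -m model-F32.gguf -f train.txt -opt sgd -lr 1e-5 -wd 1e-9 -epochs 2 -val 0.05 (the binary name, model path, data file and learning-rate value are illustrative placeholders; -m and -f are pre-existing common args, the rest are the options added above).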

common/common.h

Lines changed: 11 additions & 3 deletions
@@ -2,13 +2,14 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
-#include <sstream>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -80,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -350,6 +352,12 @@ struct common_params {
     bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct ggml_opt_optimizer_params optimize;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of data used for validation when optimizing
+    unsigned epochs = 2;
+
     // embedding
     bool embedding = false; // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
examples/training/finetune.cpp

Lines changed: 36 additions & 31 deletions
@@ -1,29 +1,31 @@
-#include "arg.h"
-#include "common.h"
-#include "log.h"
-#include "llama.h"
-
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <ctime>
 #include <vector>
 
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+#include "log.h"
+
 #if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
+#    pragma warning(disable : 4244 4267) // possible loss of data
 #endif
 
 int main(int argc, char ** argv) {
     common_params params;
+    struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
 
     params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }
 
     if (params.use_mmap) {
-        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
+                __func__);
         params.use_mmap = false;
     }
     if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +40,11 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
-    common_init_result llama_init = common_init_from_params(params);
-    llama_model_ptr & model = llama_init.model;
-    llama_context_ptr & ctx = llama_init.context;
+    common_init_result  llama_init = common_init_from_params(params);
+    llama_model_ptr &   model      = llama_init.model;
+    llama_context_ptr & ctx        = llama_init.context;
+    auto                pctx       = ctx.get();
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +57,34 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    constexpr float val_split = 0.05f;
-
-    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
-
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
-
-    struct llama_opt_params lopt_params {
-        /*n_ctx_train     =*/ 0,
-        /*param_filter    =*/ llama_opt_param_filter_all,
-        /*param_filter_ud =*/ nullptr,
-        /*get_opt_pars    =*/ ggml_opt_get_constant_optimizer_params,
-        /*get_opt_pars_ud =*/ &optimizer_params,
+    std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);
+
+    float alpha = optimizer_params.adamw.alpha;
+    optimizer_params.sgd.alpha = alpha;
+    float wd = optimizer_params.adamw.wd;
+    optimizer_params.sgd.wd = wd;
+    LOG_INF("-optimizer %s -lr %.2g -wd %.2g -epochs %d -val %.2g\n", ggml_opt_optimizer_name(params.optimizer),
+            (double) alpha, (double) wd, params.epochs, (double) params.val_split);
+
+    struct llama_opt_params lopt_params{
+        /*n_ctx_train     =*/0,
+        /*param_filter    =*/llama_opt_param_filter_all,
+        /*param_filter_ud =*/nullptr,
+        /*get_opt_pars    =*/ggml_opt_get_constant_optimizer_params,
+        /*get_opt_pars_ud =*/&optimizer_params,
+        /*optimizer_type  =*/params.optimizer,
     };
-    llama_opt_init(ctx.get(), model.get(), lopt_params);
+    llama_opt_init(pctx, model.get(), lopt_params);
 
-    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
-        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
-            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+    for (unsigned epoch = 0; epoch < params.epochs; ++epoch) {
+        llama_opt_epoch(pctx, dataset, result_train, result_eval, idata_split,
+                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
 
         ggml_opt_result_reset(result_train);
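The example wires ggml_opt_get_constant_optimizer_params with &optimizer_params as userdata, so the same alpha and wd are used for every step. Since the commit message notes that SGD's learning rate is touchy, a custom get_opt_pars callback is the natural place to experiment with a schedule. A minimal sketch, assuming the callback takes a void * userdata and returns the params struct by value (as the constant callback does), with a hypothetical lr_decay_state holder that the training loop advances each epoch:

#include <stdint.h>

#include "ggml-opt.h"

// hypothetical userdata for a per-epoch 1/(1+epoch) learning-rate decay; not part of this commit
struct lr_decay_state {
    struct ggml_opt_optimizer_params base;  // copy of params.optimize
    int64_t                          epoch; // advanced by the caller after each epoch
};

static struct ggml_opt_optimizer_params decayed_sgd_params(void * userdata) {
    const struct lr_decay_state * st = (const struct lr_decay_state *) userdata;
    struct ggml_opt_optimizer_params p = st->base;
    p.sgd.alpha = st->base.sgd.alpha / (1.0f + (float) st->epoch); // shrink the SGD step size each epoch
    return p;
}

// wiring: pass decayed_sgd_params as /*get_opt_pars =*/ and &state as /*get_opt_pars_ud =*/ in llama_opt_params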

ggml/include/ggml-opt.h

Lines changed: 27 additions & 8 deletions
@@ -74,16 +74,30 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
+    // "adamw" or "sgd" (case insensitive)
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char *);
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float alpha; // learning rate
+            float beta1; // adamw
+            float beta2; // adamw
+            float eps;   // epsilon for numerical stability
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -113,7 +127,10 @@ extern "C" {
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates m, v per parameter
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -186,7 +203,7 @@ extern "C" {
     //    The second context should contain all other tensors and will be (re)allocated automatically.
     //    Due to this automated allocation the data of the second context is not defined when accessed in user code.
     //    Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
-    // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
+    // 4. Call ggml_opt_fit. If you need more control (e.g. optimizer sgd) you can use ggml_opt_epoch instead.
 
     // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
     typedef void (*ggml_opt_epoch_callback)(
@@ -226,12 +243,14 @@ extern "C" {
            struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
            ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
            enum ggml_opt_loss_type         loss_type,      // loss to minimize
+           enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw
            ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
            int64_t                         nepoch,         // how many times the dataset should be iterated over
            int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
            float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
            bool                            silent);        // whether or not info prints to stderr should be suppressed
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t);
 #ifdef __cplusplus
 }
 #endif
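The two helpers declared above (ggml_opt_optimizer_name, ggml_opt_get_optimizer) are only declared in this header; their definitions are not part of the hunks shown here. A minimal sketch consistent with the stated contract - "adamw" or "sgd", case insensitive, with GGML_OPT_OPTIMIZER_TYPE_COUNT serving as the "invalid" result that common/arg.cpp checks for - might look like this (illustrative only; the real definitions may differ):

#include <ctype.h>

#include "ggml-opt.h"

// case-insensitive ASCII comparison helper (hypothetical, for this sketch only)
static bool opt_name_eq(const char * a, const char * b) {
    for (; *a && *b; ++a, ++b) {
        if (tolower((unsigned char) *a) != tolower((unsigned char) *b)) {
            return false;
        }
    }
    return *a == '\0' && *b == '\0';
}

const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type t) {
    switch (t) {
        case GGML_OPT_OPTIMIZER_TYPE_ADAMW: return "adamw";
        case GGML_OPT_OPTIMIZER_TYPE_SGD:   return "sgd";
        default:                            return "undefined";
    }
}

enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char * name) {
    if (opt_name_eq(name, "adamw")) { return GGML_OPT_OPTIMIZER_TYPE_ADAMW; }
    if (opt_name_eq(name, "sgd"))   { return GGML_OPT_OPTIMIZER_TYPE_SGD;   }
    return GGML_OPT_OPTIMIZER_TYPE_COUNT; // treated as "invalid --optimizer" by common/arg.cpp
}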

ggml/include/ggml.h

Lines changed: 11 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM,       // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,    // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,14 @@ extern "C" {
            struct ggml_tensor  * v,
            struct ggml_tensor  * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            // params: alpha (learning rate), wd (weight decay)
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * adamw_params);
+
     //
     // automatic differentiation
     //
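For orientation, an SGD step with (decoupled) weight decay updates each element as x ← x·(1 − alpha·wd) − alpha·grad, which is also where the cached (1 - wd*alpha) factor mentioned in the commit message comes from. A scalar sketch of that update (illustrative only; the real work happens inside the GGML_OP_OPT_STEP_SGD graph op declared above, not in a standalone helper):

#include <stdint.h>

// elementwise SGD with weight decay: x = x*(1 - alpha*wd) - alpha*g
static void sgd_step_sketch(float * x, const float * grad, int64_t n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd; // precomputed once per step, as in the commit
    for (int64_t i = 0; i < n; ++i) {
        x[i] = x[i] * keep - alpha * grad[i];
    }
}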

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2061,6 +2061,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2345,6 +2350,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
            {
                n_tasks = n_threads;
            } break;