
Commit 96bf86c

examples/finetune -opt SGD (stochastic gradient descent) memory opt
Support the finetune arg -opt SGD (or sgd).

Result on llama 3.2-1b-F32: observed 11 GB GPU RAM (45 sec/epoch) with SGD instead of 19 GB (55 sec/epoch) with AdamW. Getting the right learning rate for SGD is trickier than for AdamW: too high and you overshoot and oscillate, too low and you waste compute slowly approaching convergence. Training quickly reaches 99%+ train accuracy on a tiny Wikipedia set (~58% token accuracy on held-out eval, which is reasonable).

Note: the objective loss may not be directly comparable between AdamW and SGD; check perplexity or accuracy, or compare relative improvements toward convergence.

Also note that a logical batch size larger than the physical batch (gradient accumulation) appears unsupported during optimization: it is limited to the physical batch, unlike in perplexity (ppx), and also limited to ctx-size. Training quality/convergence could be improved by implementing it, at the cost of some memory, which could be recovered by using a much smaller physical batch for a net memory saving. Presumably it is the physical batch that should be limited to ctx-size; see llama_context::opt_epoch.

New finetune args: -wd 1e-9 to enable weight decay in SGD or AdamW, and -epochs N to set the max number of epochs (default 2, as before).

Cache (1 - wd*alpha) in the 'adamw' opt struct, and cache the per-epoch computed optimizer opts (formerly they were computed twice per epoch).

Add a unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; it avoids allocating the m, v tensors. Make ggml_opt_init aware of the optimization method. Since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could probably switch between SGD and AdamW each epoch, but would need to use AdamW for the first epoch (unconfirmed; there is no arg to set such a policy yet).

100 lines:

train: ... loss=0.00231±0.00032 acc=99.99±0.01% t=00:00:05
val:   ... loss=3.91926±nan acc=58.40±2.18%

On more training data (500 lines), additional catastrophic forgetting before train reaches 99.9% accuracy:

train: data=0000140/0000140 loss=0.02611±0.00077 acc=99.82±0.02% t=00:00:45
val:   data=0000008/0000008 loss=4.11112±0.22526 acc=46.36±0.78%

Increasing batch+ctx sizes to 1536 (double what fits in memory with AdamW) gives apparently better validation, but that could be an artifact of continuing training from previous weights, i.e. what counts as train vs. val probably depends on batch size. Also amusing: training is faster due to the larger batch, even though the larger context would be expected to be slower:

train: data=0000045/0000045 loss=0.02010±0.00138 acc=99.85±0.01% t=00:00:40
val:   data=0000003/0000003 loss=1.96829±1.09488 acc=72.44±0.66%
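For reference, the SGD update described above amounts to a plain gradient step with decoupled weight decay, using the cached keep = 1 - alpha*wd factor. A minimal standalone sketch in plain C (not the ggml kernel itself; the function name and values are illustrative):

#include <stddef.h>

// one SGD step with decoupled weight decay, as described above:
// w <- w*(1 - alpha*wd) - alpha*g, with keep = 1 - alpha*wd precomputed once
static void sgd_step(float * w, const float * g, size_t n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;  // cached once per epoch in this commit
    for (size_t i = 0; i < n; ++i) {
        w[i] = w[i] * keep - alpha * g[i];
    }
}

Unlike AdamW, no per-parameter m and v state is needed, which is where the observed GPU RAM saving comes from.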
1 parent aa59aa3 commit 96bf86c

File tree: 17 files changed (+343, -86 lines)

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
 

common/arg.cpp

Lines changed: 29 additions & 20 deletions
@@ -1094,7 +1094,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-embedding",
         "llama-eval-callback",
         "llama-export-lora",
-        "llama-finetune",
         "llama-gen-docs",
         "llama-gguf",
         "llama-gguf-hash",
@@ -1125,9 +1124,10 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-speculative-simple",
         "llama-tokenize",
         "llama-tts",
-        "llama-vdot" };
+        "llama-vdot",
+        "llama-finetune" };
 
-    for (const auto& exe : executables) {
+    for (const auto & exe : executables) {
         printf("complete -F _llama_completions %s\n", exe.c_str());
     }
 }
@@ -1237,8 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     sampler_type_names.pop_back();
 
-    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
-    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
 
     /**
      * filter options by example
@@ -1438,13 +1437,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_predict = value;
         }
     ).set_env("LLAMA_ARG_N_PREDICT"));
-    add_opt(common_arg(
-        {"-b", "--batch-size"}, "N",
-        string_format("logical maximum batch size (default: %d)", params.n_batch),
-        [](common_params & params, int value) {
-            params.n_batch = value;
-        }
-    ).set_env("LLAMA_ARG_BATCH"));
+    add_opt(common_arg({ "-b", "--batch-size" }, "N",
+                       string_format("logical maximum batch size (default: %d) - currently reduced to -ub in optimizer "
+                                     "(TODO: gradient accumulate?)",
+                                     params.n_batch),
+                       [](common_params & params, int value) { params.n_batch = value; })
+                .set_env("LLAMA_ARG_BATCH"));
     add_opt(common_arg(
         {"-ub", "--ubatch-size"}, "N",
         string_format("physical maximum batch size (default: %d)", params.n_ubatch),
@@ -2182,19 +2180,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ppl_output_type = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
-    add_opt(common_arg({ "-lr", "--learning-rate" }, "ALPHA",
-        string_format("adamw optimizer alpha (default: %.1f)", (double) params.optimize.adamw.alpha),
-        [](common_params & params, const std::string & value) {
-            params.optimize.adamw.alpha = std::stof(value);
-        })
+    add_opt(
+        common_arg(
+            { "-lr", "--learning-rate" }, "ALPHA",
+            string_format(
+                "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~100x (no momentum)",
+                (double) params.optimize.adamw.alpha),
+            [](common_params & params, const std::string & value) { params.optimize.adamw.alpha = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+                { "-wd", "--weight-decay" }, "WD",
+                string_format(
+                    "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+                    (double) params.optimize.adamw.wd),
+                [](common_params & params, const std::string & value) { params.optimize.adamw.wd = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+                       string_format("optimizer max # of epochs (default: %d)", params.optimize.epochs),
+                       [](common_params & params, int epochs) { params.optimize.epochs = epochs; })
                 .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
     add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or //TODO:sgd",
                        [](common_params & params, const std::string & name) {
-                           params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
+                           params.optimize.optimizer = ggml_opt_get_optimizer(name.c_str());
                            if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
                                throw std::invalid_argument("invalid --optimizer (try adamw)");
-                           } else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD) {
-                               throw std::invalid_argument("TODO: implement SGD");
                            }
                        })
                 .set_examples({ LLAMA_EXAMPLE_FINETUNE }));

examples/training/finetune.cpp

Lines changed: 11 additions & 3 deletions
@@ -38,7 +38,6 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
     common_init_result llama_init = common_init_from_params(params);
     llama_model_ptr & model = llama_init.model;
@@ -61,7 +60,16 @@
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
     struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
-    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double) optimizer_params.adamw.alpha);
+    if (optimizer_params.optimizer == GGML_OPT_OPTIMIZER_SGD) {
+        double was = (double) optimizer_params.adamw.alpha;
+        double by = 1e2;
+        double to = was * by;
+        LOG_INF("sgd multiplying -lr by %.3g (no momentum) from -lr: %.2g to %.2g\n", by, was, to);
+        optimizer_params.adamw.alpha = to;
+    }
+
+    LOG_INF("-optimizer %s -lr %.2g -wd %.2g -epochs %d\n", ggml_opt_optimizer_name(optimizer_params.optimizer),
+            (double) optimizer_params.adamw.alpha, (double) optimizer_params.adamw.wd, optimizer_params.epochs);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
@@ -77,7 +85,7 @@
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (unsigned epoch = 0; epoch < optimizer_params.epochs; ++epoch) {
         llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
                         ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");

ggml/include/ggml-opt.h

Lines changed: 15 additions & 8 deletions
@@ -74,28 +74,33 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
-    enum ggml_opt_optimizer {
+    enum ggml_opt_optimizer_type {
         GGML_OPT_OPTIMIZER_ADAMW,
         GGML_OPT_OPTIMIZER_SGD,
 
         GGML_OPT_OPTIMIZER_COUNT
     };
 
     // "adamw" or "sgd" (case insensitive)
-    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer);
-    GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(const char *);
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char *);
 
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
+        // SGD and AdamW optimizer parameters
         struct {
            float alpha; // learning rate
-           float beta1;
-           float beta2;
+           float beta1; // adamw
+           float beta2; // adamw
           float eps;   // epsilon for numerical stability
-           float wd;    // weight decay for AdamW, use 0.0f to disable
+           float wd;    // weight decay for SGD or AdamW, use 0.0f to disable
        } adamw;
-        enum ggml_opt_optimizer optimizer;
+
+        // only GGML_OPT_OPTIMIZER_ADAMW allocates m, v per parameter
+        enum ggml_opt_optimizer_type optimizer;
+
+        // affects finetune.cpp only so far:
+        unsigned epochs; // max # of epochs sampling over training data
    };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -126,6 +131,8 @@ extern "C" {
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
         void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+        struct ggml_opt_optimizer_params
+            opt_params; // holds result of get_opt_pars(get_opt_pars_ud) after ggml_opt_init (could call get_opt_pars repeatedly instead)
     };
 
     // get parameters for an optimization context with defaults set where possible
ggml/include/ggml.h

Lines changed: 8 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM,     // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,  // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,11 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad,
+            struct ggml_tensor * adamw_params); // parameters: alpha, the learning rate, and wd, weight decay
+
     //
     // automatic differentiation
     //
ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2057,6 +2057,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2341,6 +2346,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 4 deletions
@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8849,14 +8849,14 @@
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
     const float eps = adamw_params_ptr[3];
-    const float wd = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha = adamw_params_ptr[0];
+    const float keep = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float * w = (float *) ((char *) src0->data + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3256,6 +3260,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;
