Commit a746243

examples/finetune -opt SGD (stochastic gradient descent) memory opt
support finetune arg -opt SGD (or sgd).

llama 3.2-1b-F32 result: observed 11 GB GPU RAM (45 sec/epoch) when using SGD instead of 19 GB (55 sec/epoch) when using AdamW. (Getting the right learning rate for SGD is trickier than for AdamW: too high and you overshoot and oscillate, too low and you waste compute slowly approaching convergence.) Training quickly reaches 99%+ train accuracy on a tiny Wikipedia set (~58% token accuracy on held-out eval, which is reasonable).

Note: the objective loss is not directly comparable between AdamW and SGD; check perplexity or accuracy, or compare relative improvements toward convergence.

train: ... loss=0.00231±0.00032 acc=99.99±0.01% t=00:00:05
val:   ... loss=3.91926±nan acc=58.40±2.18%

On more training data (500 lines), additional catastrophic forgetting occurs before train reaches 99.9% accuracy:

train: data=0000140/0000140 loss=0.02752±0.00094 acc=99.78±0.02% t=00:00:45
val:   data=0000008/0000008 loss=4.16029±0.23384 acc=46.61±0.78%

New finetune args:
- -wd 1e-9 enables weight decay in SGD or AdamW
- -epochs N sets the max number of epochs (default 2, as before)

Other changes:
- cache (1 - wd*alpha) in the 'adamw' opt struct
- cache the computed per-epoch optimizer opts (formerly computed twice per epoch)
- add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; it avoids allocating the m, v momentum tensors
- make ggml_opt_init aware of the optimization method

Since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could probably switch between SGD and AdamW on each epoch, but it would need to use AdamW for the first (unconfirmed; there is no arg to set such a policy yet).
1 parent aa59aa3 commit a746243
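As a rough sanity check on the reported savings (my estimate, not from the commit): AdamW keeps two extra F32 state tensors, the momenta m and v, for every trainable parameter, so for a ~1.2B-parameter F32 model that is about 2 * 4 bytes * 1.2e9 ≈ 9-10 GB of optimizer state that SGD never allocates, in the same ballpark as the observed drop from ~19 GB to ~11 GB of GPU RAM.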

File tree

16 files changed, +336 -82 lines changed


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 25 additions & 16 deletions
@@ -1237,8 +1237,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     sampler_type_names.pop_back();
 
-    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
-    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
 
     /**
      * filter options by example
@@ -1438,13 +1437,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_predict = value;
         }
     ).set_env("LLAMA_ARG_N_PREDICT"));
-    add_opt(common_arg(
-        {"-b", "--batch-size"}, "N",
-        string_format("logical maximum batch size (default: %d)", params.n_batch),
-        [](common_params & params, int value) {
-            params.n_batch = value;
-        }
-    ).set_env("LLAMA_ARG_BATCH"));
+    add_opt(common_arg({ "-b", "--batch-size" }, "N",
+            string_format("logical maximum batch size (default: %d) - currently reduced to -ub in optimizer "
+                          "(TODO: gradient accumulate?)",
+                          params.n_batch),
+            [](common_params & params, int value) { params.n_batch = value; })
+        .set_env("LLAMA_ARG_BATCH"));
     add_opt(common_arg(
         {"-ub", "--ubatch-size"}, "N",
         string_format("physical maximum batch size (default: %d)", params.n_ubatch),
@@ -2182,19 +2180,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ppl_output_type = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
-    add_opt(common_arg({ "-lr", "--learning-rate" }, "ALPHA",
-            string_format("adamw optimizer alpha (default: %.1f)", (double) params.optimize.adamw.alpha),
-            [](common_params & params, const std::string & value) {
-                params.optimize.adamw.alpha = std::stof(value);
-            })
+    add_opt(
+        common_arg(
+            { "-lr", "--learning-rate" }, "ALPHA",
+            string_format(
+                "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~100x (no momentum)",
+                (double) params.optimize.adamw.alpha),
+            [](common_params & params, const std::string & value) { params.optimize.adamw.alpha = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+            { "-wd", "--weight-decay" }, "WD",
+            string_format(
+                "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+                (double) params.optimize.adamw.wd),
+            [](common_params & params, const std::string & value) { params.optimize.adamw.wd = std::stof(value); })
+        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+            string_format("optimizer max # of epochs (default: %d)", params.optimize.epochs),
+            [](common_params & params, int epochs) { params.optimize.epochs = epochs; })
        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or //TODO:sgd",
            [](common_params & params, const std::string & name) {
                params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
                if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
                    throw std::invalid_argument("invalid --optimizer (try adamw)");
-                } else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD) {
-                    throw std::invalid_argument("TODO: implement SGD");
                }
            })
        .set_examples({ LLAMA_EXAMPLE_FINETUNE }));

examples/training/finetune.cpp

Lines changed: 11 additions & 3 deletions
@@ -38,7 +38,6 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
     common_init_result llama_init = common_init_from_params(params);
     llama_model_ptr & model = llama_init.model;
@@ -61,7 +60,16 @@ int main(int argc, char ** argv) {
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
     struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
-    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double) optimizer_params.adamw.alpha);
+    if (optimizer_params.optimizer == GGML_OPT_OPTIMIZER_SGD) {
+        double was = (double) optimizer_params.adamw.alpha;
+        double by  = 1e2;
+        double to  = was * by;
+        LOG_INF("sgd multiplying -lr by %.3g (no momentum) from -lr: %.2g to %.2g\n", by, was, to);
+        optimizer_params.adamw.alpha = to;
+    }
+
+    LOG_INF("-optimizer %s -lr %.2g -wd %.2g -epochs %d\n", ggml_opt_optimizer_name(optimizer_params.optimizer),
+            (double) optimizer_params.adamw.alpha, (double) optimizer_params.adamw.wd, optimizer_params.epochs);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
@@ -77,7 +85,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval  = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (unsigned epoch = 0; epoch < optimizer_params.epochs; ++epoch) {
         llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
                         ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
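Note on the hunk above: when -opt sgd is selected, the example multiplies the user-supplied -lr by a fixed factor of 1e2 before training starts; for instance, -lr 1e-7 would run SGD with an effective alpha of 1e-5 (the 100x factor comes from the code, the example value is illustrative).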

ggml/include/ggml-opt.h

Lines changed: 15 additions & 8 deletions
@@ -74,28 +74,33 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
-    enum ggml_opt_optimizer {
+    enum ggml_opt_optimizer_type {
         GGML_OPT_OPTIMIZER_ADAMW,
         GGML_OPT_OPTIMIZER_SGD,
 
         GGML_OPT_OPTIMIZER_COUNT
     };
 
     // "adamw" or "sgd" (case insensitive)
-    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer);
-    GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(const char *);
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+    GGML_API enum ggml_opt_optimizer_type named_ggml_opt_optimizer(const char *);
 
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
+        // SGD and AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // adamw
+            float beta2; // adamw
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay for SGD or AdamW, use 0.0f to disable
         } adamw;
-        enum ggml_opt_optimizer optimizer;
+
+        // only GGML_OPT_OPTIMIZER_ADAMW allocates m, v per parameter
+        enum ggml_opt_optimizer_type optimizer;
+
+        // affects finetune.cpp only so far:
+        unsigned epochs; // max # of epochs sampling over training data
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -126,6 +131,8 @@ extern "C" {
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
        void * get_opt_pars_ud;                      // userdata for calculating optimizer parameters
+        struct ggml_opt_optimizer_params
+            opt_params; // holds result of get_opt_pars(get_opt_pars_ud) after ggml_opt_init (could call get_opt_pars repeatedly instead)
    };
 
    // get parameters for an optimization context with defaults set where possible
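To illustrate how these fields fit together (a sketch of mine, not part of the commit): optimizer settings can be supplied through the get_opt_pars callback shown in the second hunk; the callback shape (userdata in, params struct returned by value) and the defaults helper are assumed from their uses elsewhere in this diff, and the numeric values are placeholders only.

    // hypothetical callback selecting SGD; field names are from the struct above
    static struct ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
        struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
        p.optimizer   = GGML_OPT_OPTIMIZER_SGD;
        p.adamw.alpha = 1e-6f; // SGD typically wants a larger lr than AdamW (no momentum)
        p.adamw.wd    = 1e-9f; // decoupled weight decay; 0.0f disables
        p.epochs      = 2;     // only read by the finetune example so far
        return p;
    }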

ggml/include/ggml.h

Lines changed: 8 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,11 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad,
+            struct ggml_tensor * adamw_params); // parameters: alpha, the learning rate, and wd, weight decay
+
     //
     // automatic differentiation
     //
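As a usage sketch (mine, not from this commit): the new op would be wired into an optimizer graph much like GGML_OP_OPT_STEP_ADAMW. The 8-element F32 parameter tensor layout is taken from the kernels later in this diff ([0]=alpha, [1]=beta1, [2]=beta2, [3]=eps, [4]=wd, [5]=beta1h, [6]=beta2h, [7]=keep = 1 - alpha*wd), of which SGD reads only alpha and keep. Here ctx, weight, grad, and gb_opt stand for an existing context, a trainable tensor, its gradient, and the optimizer-step graph.

    // hypothetical wiring; SGD only reads pars[0] (alpha) and pars[7] (keep)
    struct ggml_tensor * pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * step = ggml_opt_step_sgd(ctx, weight, grad, pars);
    ggml_build_forward_expand(gb_opt, step);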

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2057,6 +2057,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2341,6 +2346,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 4 deletions
@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8849,14 +8849,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
     const float eps    = adamw_params_ptr[3];
-    const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float   alpha            = adamw_params_ptr[0];
+    const float   keep             = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float *       w = (float *) ((char *) src0->data + offset);                   // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
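For reference, the kernel added above implements plain SGD with decoupled weight decay: w <- keep * w - alpha * g, where keep = 1 - alpha*wd is precomputed on the host (the commit message notes it is cached in the 'adamw' opt struct) and passed at index 7 of the shared adamw_params tensor; SGD ignores the remaining AdamW entries.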

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3256,6 +3260,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;

ggml/src/ggml-cuda/opt-step-adamw.cu

Lines changed: 7 additions & 4 deletions
@@ -17,9 +17,9 @@ static __global__ void opt_step_adamw_f32(
     const float beta1  = pars[1];
     const float beta2  = pars[2];
     const float eps    = pars[3];
-    const float wd     = pars[4];
     const float beta1h = pars[5];
     const float beta2h = pars[6];
+    const float keep   = pars[7];
 
     const float gi  = g[i];
     const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
@@ -31,7 +31,11 @@ static __global__ void opt_step_adamw_f32(
     const float mh = gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;
 
-    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+#if 1
+    x[i] = x[i] * (1.f - alpha * pars[4]) - alpha * mh / vh;
+#else
+    x[i] = x[i] * keep - alpha * mh / vh;
+#endif
 }
 
 static void opt_step_adamw_f32_cuda(
@@ -62,14 +66,13 @@ void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     float       * src0_d         = (float *) src0->data;
     const float * src0_grad_d    = (const float *) src0_grad->data;
     float       * src0_grad_m_d  = (float *) src0_grad_m->data;
     float       * src0_grad_v_d  = (float *) src0_grad_v->data;
     const float * adamw_params_d = (const float *) adamw_params->data;
-
     cudaStream_t stream = ctx.stream();
 
     const int64_t ne = ggml_nelements(src0);