
Commit 085f870

examples/finetune -opt SGD (stochastic gradient descent) memory opt
Support the finetune arg -opt SGD (or sgd). llama 3.2-1b-F32 result: observed 11 GB GPU RAM (45 s/epoch) when using SGD instead of 19 GB (55 s/epoch) using AdamW. (Getting the right learning rate for SGD is trickier than for AdamW - too high and you overshoot and oscillate, too low and you waste compute slowly approaching convergence.) Either SGD or AdamW quickly reaches 99%+ train accuracy.

Note: objective loss is not directly comparable between AdamW and SGD - check perplexity or accuracy, or consider relative improvements, when judging convergence. Also note that a logical batch size larger than the physical batch (gradient accumulation) seems unsupported for optimization (it is limited to the physical batch, unlike in perplexity - and also limited to ctx-size). Training quality/convergence could be improved by implementing it, at the cost of some memory, but that can be made up by using a much smaller physical batch for a net memory savings. Presumably it is the physical batch that should be limited to ctx-size? See llama_context::opt_epoch.

New finetune args: -wd 1e-9 to enable weight decay in SGD or AdamW, and -epochs N to set the maximum number of epochs (default 2, as before).

Cache (1 - wd*alpha) in the 'adamw' opt struct - no noticeable perf benefit. Cache the per-epoch computed optimizer opts (formerly they were computed twice per epoch).

Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - it avoids allocating the m, v tensors. Make ggml_opt_init aware of the optimization method. Since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params callback would probably be able to switch between SGD and AdamW with each epoch, but it would need to use AdamW for the first (unconfirmed - no arg to set such a policy yet).

100 lines of wikipedia train:

train: ... loss=0.00231±0.00032 acc=99.99±0.01% t=00:00:05
val:   ... loss=3.91926±nan acc=58.40±2.18%

On more training data (500 lines), additional catastrophic forgetting before train reaches 99.9% accuracy:

train: data=0000140/0000140 loss=0.02611±0.00077 acc=99.82±0.02% t=00:00:45
val:   data=0000008/0000008 loss=4.11112±0.22526 acc=46.36±0.78%

Increasing batch+ctx sizes to 1536 (double what fits in memory for AdamW) gets apparently better validation, but that could be an artifact of continuing training from the previous weights, i.e. what counts as train vs. val probably depends on the batch size. Also amusing - faster due to the larger batch, even though the larger context should be slower?

train: data=0000045/0000045 loss=0.01722±0.00103 acc=99.90±0.01% t=00:00:40
val:   data=0000003/0000003 loss=1.96829±1.09488 acc=72.44±0.66%
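For reference, the SGD step this commit adds (see ggml/src/ggml-cpu/ops.cpp below) reduces to plain gradient descent with decoupled weight decay. A minimal per-element sketch - the helper name is illustrative, not part of the commit:

// Sketch: the per-element SGD update implemented in this commit, with `keep`
// caching (1 - alpha*wd) so the decayed weight costs a single multiply.
static void sgd_step(float * w, const float * g, int n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;  // 1.0f when weight decay is disabled
    for (int i = 0; i < n; ++i) {
        w[i] = w[i] * keep - alpha * g[i]; // decoupled weight decay, then gradient step
    }
}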
1 parent ed5bda6 commit 085f870

File tree

20 files changed: +1703 -1631 lines

.clang-format

Lines changed: 1 addition & 3 deletions
@@ -23,7 +23,7 @@ AllowShortLambdasOnASingleLine: Inline
 AllowShortLoopsOnASingleLine: false
 AlwaysBreakBeforeMultilineStrings: true
 BinPackArguments: true
-BinPackParameters: true # OnePerLine
+BinPackParameters: true
 BitFieldColonSpacing: Both
 BreakBeforeBraces: Custom # Attach
 BraceWrapping:
@@ -45,7 +45,6 @@ BraceWrapping:
 SplitEmptyFunction: false
 SplitEmptyRecord: false
 SplitEmptyNamespace: false
-# BreakAdjacentStringLiterals: true
 BreakAfterAttributes: Never
 BreakBeforeBinaryOperators: None
 BreakBeforeInlineASMColon: OnlyMultiline
@@ -158,4 +157,3 @@ TabWidth: 4
 UseTab: Never
 WhitespaceSensitiveMacros: ['STRINGIZE']
 ...
-

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
 

common/arg.cpp

Lines changed: 1375 additions & 1548 deletions
Large diffs are not rendered by default.

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -354,6 +354,8 @@ struct common_params {
 
     // finetune
     struct ggml_opt_optimizer_params optimize;
+    unsigned epochs = 2;
+
     // embedding
     bool embedding = false; // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)

examples/training/finetune.cpp

Lines changed: 11 additions & 3 deletions
@@ -38,7 +38,6 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
     llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
     common_init_result llama_init = common_init_from_params(params);
     llama_model_ptr & model = llama_init.model;
@@ -61,7 +60,16 @@
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
     struct ggml_opt_optimizer_params & optimizer_params = params.optimize;
-    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double) optimizer_params.adamw.alpha);
+    if (optimizer_params.optimizer == GGML_OPT_OPTIMIZER_SGD) {
+        double was = (double) optimizer_params.common.alpha;
+        double by = 1e2;
+        double to = was * by;
+        LOG_INF("sgd multiplying -lr by %.3g (no momentum) from -lr: %.2g to %.2g\n", by, was, to);
+        optimizer_params.common.alpha = to;
+    }
+
+    LOG_INF("-optimizer %s -lr %.2g -wd %.2g -epochs %d\n", ggml_opt_optimizer_name(optimizer_params.optimizer),
+            (double) optimizer_params.common.alpha, (double) optimizer_params.common.wd, params.epochs);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
@@ -77,7 +85,7 @@
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
+    for (unsigned epoch = 0; epoch < params.epochs; ++epoch) {
         llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
                         ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");

ggml/include/ggml-opt.h

Lines changed: 16 additions & 10 deletions
@@ -74,28 +74,34 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
-    enum ggml_opt_optimizer {
+    enum ggml_opt_optimizer_type {
         GGML_OPT_OPTIMIZER_ADAMW,
         GGML_OPT_OPTIMIZER_SGD,
 
         GGML_OPT_OPTIMIZER_COUNT
     };
 
     // "adamw" or "sgd" (case insensitive)
-    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer);
-    GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(const char *);
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char *);
 
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
+        // SGD and AdamW optimizer parameters
+        struct {
+            float alpha; // learning rate
+            float wd; // weight decay for SGD or AdamW, use 0.0f to disable
+        } common;
+
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps; // epsilon for numerical stability
-            float wd; // weight decay for AdamW, use 0.0f to disable
+            float beta1; // adamw
+            float beta2; // adamw
+            float eps; // epsilon for numerical stability
         } adamw;
-        enum ggml_opt_optimizer optimizer;
+
+        // only GGML_OPT_OPTIMIZER_ADAMW allocates m, v per parameter
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -125,7 +131,7 @@ extern "C" {
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
     };
 
     // get parameters for an optimization context with defaults set where possible

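As context for the struct change above (not part of the diff): a minimal sketch of filling the new common/adamw split for an SGD run, using only the fields shown in this header. The helper name is illustrative, and the callback shape (userdata pointer in, struct returned by value) is an assumption modeled on the get_opt_pars callback mentioned above.

#include "ggml-opt.h"

// Sketch: optimizer parameters for an SGD run, mirroring the get_opt_pars
// callback shape (assumed: void * userdata in, struct returned by value).
static struct ggml_opt_optimizer_params sgd_opt_pars(void * userdata) {
    (void) userdata;                          // unused in this sketch
    struct ggml_opt_optimizer_params p = {0};
    p.optimizer    = GGML_OPT_OPTIMIZER_SGD;
    p.common.alpha = 1e-3f;                   // learning rate (-lr)
    p.common.wd    = 1e-9f;                   // weight decay (-wd), 0.0f to disable
    // p.adamw.beta1 / beta2 / eps are only consulted for GGML_OPT_OPTIMIZER_ADAMW
    return p;
}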
ggml/include/ggml.h

Lines changed: 9 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,12 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(struct ggml_context * ctx, struct ggml_tensor * a,
+                                                    struct ggml_tensor * grad,
+                                                    // parameters: alpha (learning rate), wd (weight decay)
+                                                    struct ggml_tensor * adamw_params);
+
     //
     // automatic differentiation
     //

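A small sketch (not from the commit) of how the declaration above could be wired into a graph; llama.cpp's ggml-opt layer normally does this internally, so the helper below and its arguments are purely illustrative:

#include "ggml.h"

// Sketch: append one SGD step op for a parameter tensor, following the
// ggml_opt_step_sgd declaration above. The 8-float parameter tensor matches
// the size asserted by the CPU kernel in ops.cpp further down; filling it is
// left to the caller.
static void add_sgd_step(struct ggml_context * ctx, struct ggml_cgraph * gf,
                         struct ggml_tensor * weight, struct ggml_tensor * grad) {
    struct ggml_tensor * opt_pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * step     = ggml_opt_step_sgd(ctx, weight, grad, opt_pars);
    ggml_build_forward_expand(gf, step);
}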
ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2061,6 +2061,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2345,6 +2350,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 4 deletions
@@ -8946,7 +8946,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8964,14 +8964,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
     const float eps = adamw_params_ptr[3];
-    const float wd = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8994,7 +8994,7 @@
                 // The weight decay is applied independently of the Adam momenta m and v.
                 // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
                 // See: https://arxiv.org/pdf/1711.05101v3.pdf
-                w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+                w[i00] = w[i00] * keep - alpha * mh / vh;
             }
         }
     }
@@ -9016,3 +9016,63 @@
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha = adamw_params_ptr[0];
+    const float keep = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float * w = (float *) ((char *) src0->data + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}

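Both step kernels above consume the same 8-element f32 parameter tensor (hence the ggml_nelements(adamw_params) == 8 asserts). A sketch of that layout, inferred only from the indices read above; slot 4 and the helper name are assumptions, and the bias-corrected beta1h/beta2h terms are taken as inputs rather than derived here:

// Sketch: layout of the shared optimizer-step parameter buffer, inferred from
// the adamw_params_ptr[...] reads in the kernels above.
static void fill_opt_step_params(float p[8], float alpha, float beta1, float beta2,
                                 float eps, float wd, float beta1h, float beta2h) {
    p[0] = alpha;             // learning rate (read by AdamW and SGD)
    p[1] = beta1;             // AdamW only
    p[2] = beta2;             // AdamW only
    p[3] = eps;               // AdamW only
    p[4] = wd;                // raw weight decay; assumed to keep its old slot, not read by the kernels above
    p[5] = beta1h;            // AdamW bias-corrected term
    p[6] = beta2h;            // AdamW bias-corrected term
    p[7] = 1.0f - alpha * wd; // `keep`, the cached (1 - wd*alpha) from the commit message (read by both)
}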
ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif
