
Commit 75e5564

graehl, 0cc4m, and JohannesGaessler authored and committed
finetune: SGD optimizer, more CLI args (llama/13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt

  Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating the m, v momentum tensors. Support the finetune.cpp arg -opt SGD (or sgd); the default is adamw, as before.

  llama 3.2-1b-F32 result: observed 11 GB GPU RAM (41 sec/epoch) with SGD instead of 19 GB (55 sec/epoch) with AdamW (wikipedia 100-line finetune).

  Using the same GPU memory, AdamW can only reach 512 batch/context before OOM:

    train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
    val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

  SGD is superior, though it converges more slowly, with a max of 1728 batch/context before OOM (note especially the better validation performance):

    train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
    val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00

  Note: when finetuning long enough (or with a large enough -lr), validation accuracy *eventually* drops ('catastrophic forgetting').

  The -lr-half (half-life) option is useful for SGD to avoid oscillation or very slow underdamped learning (it makes setting -lr more forgiving). The terminal -lr is for now set by -lr-halvings, i.e. if you want at most 1/8 of the initial -lr you set -lr-halvings 3.

  Note: the objective loss is not directly comparable between AdamW and SGD - check perplexity or accuracy, or consider relative improvements, when judging convergence.

  New finetune args: -wd 1e-9 to enable weight decay in SGD or AdamW, and -epochs N for the maximum number of epochs (default 2, as before).

  Cache (1 - wd*alpha) in the 'adamw' opt struct - no noticeable perf benefit, so it is disabled there (still done for the new SGD path).

  Since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params callback could probably switch between SGD and AdamW on each epoch, but it would need to use AdamW for the first one (unconfirmed - there is no cmdline arg to set such a policy yet).

  test-opt checks AdamW as before and now also SGD (except for a few tests disabled for SGD only; those probably just need logged values and alternate reference values added). The tolerance on the 'regression' test is broader for SGD so that we don't need many more epochs.

* Vulkan: Implement GGML_OP_OPT_STEP_SGD

* tests: Fix OPT_STEP_SGD test-backend-ops

* SGD op params store weight decay, not 1-alpha*wd

* minor + cosmetic changes

* fix vulkan sgd

* try CI fix

---------

Co-authored-by: 0cc4m <[email protected]>
Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 78d736a commit 75e5564
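
For readers of the commit message above: the new GGML_OP_OPT_STEP_SGD op applies plain SGD with decoupled weight decay and keeps no per-parameter m/v state, which is where the memory saving over AdamW comes from. Below is a minimal standalone C++ sketch of that update rule, for illustration only (sgd_step is not a ggml function):

    #include <cstddef>

    // Illustrative scalar form of the update done by GGML_OP_OPT_STEP_SGD:
    // scale each weight by (1 - alpha*wd) (decoupled weight decay), then
    // step against the gradient by the learning rate alpha.
    static void sgd_step(float * w, const float * g, std::size_t n, float alpha, float wd) {
        const float keep = 1.0f - alpha * wd; // cached once, as in the CPU/CUDA kernels below
        for (std::size_t i = 0; i < n; ++i) {
            w[i] = w[i] * keep - alpha * g[i];
        }
    }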

File tree

15 files changed: +556 additions, −165 deletions


include/ggml-opt.h

Lines changed: 25 additions & 6 deletions
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor * outputs,               // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t dataset,                 // dataset with data and optionally also labels
             enum ggml_opt_loss_type loss_type,          // loss to minimize
+            enum ggml_opt_optimizer_type optimizer,     // sgd or adamw
             ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t nepoch,                             // how many times the dataset should be iterated over
             int64_t nbatch_logical,                     // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float val_split,                            // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool silent);                               // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
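
For context, a minimal sketch of a user-supplied get_opt_pars callback that fills both parameter blocks of the struct above. The callback signature (returning the struct by value and taking a void * userdata) follows ggml-opt.h; the numeric values are placeholders, not recommendations:

    #include "ggml-opt.h"

    // Illustrative callback for the get_opt_pars field above. Which block is
    // actually read depends on the optimizer selected for the opt context.
    static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
        (void) userdata;                    // e.g. a pointer to the current epoch
        struct ggml_opt_optimizer_params p = {};
        p.adamw.alpha = 1e-3f;              // learning rate
        p.adamw.beta1 = 0.9f;               // first AdamW momentum
        p.adamw.beta2 = 0.999f;             // second AdamW momentum
        p.adamw.eps   = 1e-8f;              // epsilon for numerical stability
        p.adamw.wd    = 0.0f;               // weight decay disabled
        p.sgd.alpha   = 1e-3f;              // learning rate for the SGD path
        p.sgd.wd      = 1e-9f;              // matches the -wd 1e-9 example in the commit message
        return p;
    }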

include/ggml.h

Lines changed: 9 additions & 1 deletion
@@ -542,6 +542,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_GLU,
 
@@ -2311,7 +2312,14 @@ extern "C" {
             struct ggml_tensor * grad,
             struct ggml_tensor * m,
             struct ggml_tensor * v,
-            struct ggml_tensor * adamw_params); // parameters such a the learning rate
+            struct ggml_tensor * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * grad,
+            struct ggml_tensor * sgd_params); // alpha, weight decay
 
     //
     // automatic differentiation
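
A minimal sketch of how the new op could be wired into a graph, based on the declaration above: sgd_params is a 2-element F32 tensor holding {alpha, wd}, which matches the ggml_nelements(params) == 2 checks in the CPU and CUDA implementations further down. The surrounding setup (weights/grads tensors, a context that allocates tensor data) is assumed, not taken from this patch:

    #include "ggml.h"

    // Illustrative helper, not part of the patch. Assumes `weights` and `grads`
    // are existing F32 tensors of the same shape in a context with data allocated.
    static struct ggml_tensor * build_sgd_step(struct ggml_context * ctx,
                                               struct ggml_tensor * weights,
                                               struct ggml_tensor * grads,
                                               float alpha, float wd) {
        struct ggml_tensor * sgd_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
        ((float *) sgd_params->data)[0] = alpha; // learning rate
        ((float *) sgd_params->data)[1] = wd;    // weight decay, 0.0f to disable
        return ggml_opt_step_sgd(ctx, weights, grads, sgd_params);
    }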

src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;

src/ggml-cpu/ops.cpp

Lines changed: 63 additions & 2 deletions
@@ -10330,14 +10330,15 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
     const float eps    = adamw_params_ptr[3];
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float alpha = sgd_params_ptr[0];
+    const float keep  = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
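
The row partitioning above mirrors the AdamW op: each of the nth threads takes a contiguous slice of ceil(nr/nth) rows, and the MIN clamp keeps the last slice in range. A small standalone check of that arithmetic (illustrative, not ggml code):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nr  = 10; // rows in the parameter tensor
        const int nth = 4;  // threads
        const int dr  = (nr + nth - 1) / nth; // rows per thread, rounded up: 3
        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr * ith;
            const int ir1 = std::min(ir0 + dr, nr);
            std::printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1); // thread 3 gets [9, 10)
        }
        return 0;
    }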

src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif

src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2479,6 +2480,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3536,6 +3540,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;

src/ggml-cuda/opt-step-sgd.cu

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0      = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * params    = dst->src[2];
+
+    GGML_ASSERT(src0->type      == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->type    == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    float       * src0_d      = (float       *) src0->data;
+    const float * src0_grad_d = (const float *) src0_grad->data;
+    const float * params_d    = (const float *) params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}
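
The launch above flattens the tensor to k elements and assigns one thread per element in blocks of CUDA_OPT_STEP_SGD_BLOCK_SIZE (256); threads with i >= k return immediately, so the final partial block is safe. A small host-side sketch of the grid-size arithmetic (illustrative only):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t k          = 1234567; // ggml_nelements(src0)
        const int     block_size = 256;     // CUDA_OPT_STEP_SGD_BLOCK_SIZE
        const int64_t blocks     = (k + block_size - 1) / block_size; // 4823 blocks
        std::printf("%lld blocks of %d threads\n", (long long) blocks, block_size);
        return 0;
    }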

src/ggml-cuda/opt-step-sgd.cuh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
