
Commit 5d8b068

JohannesGaessler authored and ggerganov committed
llama/ggml: add LLM training support (llama/10544)
* llama/ggml: add LLM training support
  - more compact progress bar
  - llama_save_model_to_file
  - llama_opt_param_filter
  - ggml_graph_dup force_grads
  - refactor ggml_opt, fix test-opt
* remove logits_all
* refactor CUDA implementation for ACC
* reset graph at beginning of opt period
1 parent 93ef226 commit 5d8b068

File tree

7 files changed: +486 -271 lines changed

ggml/include/ggml-opt.h

Lines changed: 47 additions & 28 deletions
@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======

     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,    // the type for the internal data tensor
+            enum ggml_type type_label,   // the type for the internal labels tensor
+            int64_t        ne_datapoint, // number of elements per datapoint
+            int64_t        ne_label,     // number of elements per label
+            int64_t        ndata,        // total number of datapoints/labels
+            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)

     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);

     // get underlying tensors that store the data
+    GGML_API int64_t              ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
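
As a quick illustration of the extended dataset constructor above, here is a hedged usage sketch. It relies only on the signatures visible in this hunk; the element counts and the choice of f32 for both tensors are made up for illustration.

```c
// Hypothetical usage sketch of the new ggml_opt_dataset_init signature.
// The ne_datapoint/ne_label/ndata values below are illustrative only.
ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
    GGML_TYPE_F32,            // type_data:  internal data tensor stored as f32
    GGML_TYPE_F32,            // type_label: internal labels tensor stored as f32
    /*ne_datapoint =*/ 784,   // elements per datapoint (e.g. a 28x28 image)
    /*ne_label     =*/ 10,    // elements per label (e.g. one-hot over 10 classes)
    /*ndata        =*/ 60000, // total number of datapoints/labels
    /*ndata_shard  =*/ 1000); // shuffling/copy granularity

// ... fill ggml_opt_dataset_data(dataset) and ggml_opt_dataset_labels(dataset) ...

ggml_opt_dataset_free(dataset);
```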

@@ -56,13 +59,19 @@ extern "C" {
             struct ggml_tensor   * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor   * labels_batch, // shape = [ne_label, ndata_batch]
             int64_t                ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t     dataset,
+            void                 * data_batch,
+            size_t                 nb_data_batch,
+            void                 * labels_batch,
+            int64_t                ibatch);

     // ====== Model / Context ======

     enum ggml_opt_build_type {
-        GGML_OPT_BUILD_TYPE_FORWARD,
-        GGML_OPT_BUILD_TYPE_GRAD,
-        GGML_OPT_BUILD_TYPE_OPT,
+        GGML_OPT_BUILD_TYPE_FORWARD = 10,
+        GGML_OPT_BUILD_TYPE_GRAD    = 20,
+        GGML_OPT_BUILD_TYPE_OPT     = 30,
     };

     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
     // userdata can be used to pass arbitrary data
     typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);

-    // returns the default optimizer params (constant)
+    // returns the default optimizer params (constant, hard-coded values)
     // userdata is not used
     GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);

+    // casts userdata to ggml_opt_optimizer_params and returns it
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
+
     // parameters for initializing a new optimization context
     struct ggml_opt_params {
         ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs

-        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
-        // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
-        struct ggml_tensor * inputs;
-        struct ggml_tensor * outputs;
+        // by default the forward graph needs to be reconstructed for each eval
+        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+        struct ggml_context * ctx_compute;
+        struct ggml_tensor  * inputs;
+        struct ggml_tensor  * outputs;

         enum ggml_opt_loss_type  loss_type;
         enum ggml_opt_build_type build_type;
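
To make the callback type above concrete, here is a hedged sketch of a user-defined `ggml_opt_get_optimizer_params` implementation. It uses only the signatures shown in this hunk, reads the epoch from userdata (as documented for ggml_opt_fit below), and simply returns the hard-coded defaults; the function name is illustrative.

```c
// Illustrative callback matching the ggml_opt_get_optimizer_params typedef.
// A real schedule would modify the returned struct based on `epoch`.
static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
    const int64_t epoch = *(const int64_t *) userdata; // ggml_opt_fit passes a pointer to the epoch

    struct ggml_opt_optimizer_params params = ggml_opt_get_default_optimizer_params(NULL);
    (void) epoch; // e.g. decay the learning rate here as training progresses
    return params;
}
```
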
@@ -107,12 +118,9 @@ extern "C" {

     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
-            ggml_backend_sched_t      backend_sched,
-            struct ggml_context     * ctx_compute,
-            struct ggml_tensor      * inputs,
-            struct ggml_tensor      * outputs,
-            enum ggml_opt_loss_type   loss_type);
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t      backend_sched,
+            enum ggml_opt_loss_type   loss_type);

     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@@ -121,18 +129,20 @@ extern "C" {
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);

     // get underlying tensors that store data
+    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
     GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
     GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
     GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
     GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
     GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels

+    // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);

     // ====== Optimization Result ======

-    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
     GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
     GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);

@@ -144,11 +154,20 @@ extern "C" {

     // ====== Computation ======

-    // do forward pass, increment result if not NULL
-    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // if not using static graphs, this function must be called prior to ggml_opt_alloc
+    GGML_API void ggml_opt_prepare_alloc(
+            ggml_opt_context_t    opt_ctx,
+            struct ggml_context * ctx_compute,
+            struct ggml_cgraph  * gf,
+            struct ggml_tensor  * inputs,
+            struct ggml_tensor  * outputs);
+
+    // allocate the next graph for evaluation, either forward or forward + backward
+    // must be called exactly once prior to calling ggml_opt_eval
+    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);

-    // do forward pass, increment result if not NULL, do backward pass
-    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // do forward pass, increment result if not NULL, do backward pass if allocated
+    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);

     // ############################################################################
     // ## The high-level functions start here. They do not depend on any private ##
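
Taken together, the two new entry points above replace the old ggml_opt_forward/ggml_opt_forward_backward pair. Below is a hedged sketch of one training step with a statically allocated context; `opt_ctx`, `dataset`, `result`, and `nbatches` are assumed to exist in the surrounding code, and for non-static graphs ggml_opt_prepare_alloc would additionally be called before each ggml_opt_alloc.

```c
// Sketch of one epoch with the split alloc/eval API, assuming static graphs
// (ctx_compute, inputs, and outputs were set in ggml_opt_params).
for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {
    // copy the next batch into the optimization context's input/label tensors
    ggml_opt_dataset_get_batch(dataset, ggml_opt_inputs(opt_ctx), ggml_opt_labels(opt_ctx), ibatch);

    ggml_opt_alloc(opt_ctx, /*backward =*/ true); // forward + backward graph for training
    ggml_opt_eval (opt_ctx, result);              // replaces the old ggml_opt_forward_backward
}
```
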
@@ -200,9 +219,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
             ggml_backend_sched_t            backend_sched, // backend scheduler for constructing the compute graphs
-            ggml_context                  * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
-            ggml_tensor                   * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
-            ggml_tensor                   * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            struct ggml_context           * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
+            struct ggml_tensor            * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
+            struct ggml_tensor            * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,       // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,     // loss to minimize
             ggml_opt_get_optimizer_params   get_opt_pars,  // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)

ggml/include/ggml.h

Lines changed: 6 additions & 7 deletions
@@ -768,7 +768,7 @@ extern "C" {
     // Tensor flags
     GGML_API void ggml_set_input(struct ggml_tensor * tensor);
     GGML_API void ggml_set_output(struct ggml_tensor * tensor);
-    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_param(struct ggml_tensor * tensor);
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);

     //
@@ -938,7 +938,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride

     // concat a and b along dim
     // used in stable-diffusion
@@ -2049,15 +2049,14 @@ extern "C" {

     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
-        struct ggml_context *  ctx_static,  // context for static gradients (loss + gradient accumulation)
-        struct ggml_context *  ctx_compute, // context for gradient computation
-        struct ggml_cgraph  *  cgraph,
-        bool                   accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+        struct ggml_context *  ctx,         // context for gradient computation
+        struct ggml_cgraph  *  cgraph,
+        struct ggml_tensor  ** grad_accs);

     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
-    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
     GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);
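
For context on the reworked graph API above, here is a minimal, hedged sketch of building a forward graph and expanding a duplicate with gradients under the new signatures. It assumes an existing `ggml_context * ctx`, it assumes that passing NULL for `grad_accs` is acceptable when no explicit gradient accumulators are wanted, and the tensor shape and loss are purely illustrative.

```c
// Illustrative only: a tiny parameter tensor, a placeholder loss, and a backward pass.
struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
ggml_set_param(w); // note: no ggml_context argument anymore after this commit

struct ggml_tensor * loss = ggml_sum(ctx, ggml_sqr(ctx, w)); // placeholder scalar loss
ggml_set_loss(loss);

struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
ggml_build_forward_expand(gf, loss);

// Duplicate the forward graph and extend the copy with gradient nodes;
// force_grads = true keeps gradient bookkeeping in the duplicate.
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf, /*force_grads =*/ true);
ggml_build_backward_expand(ctx, gb, /*grad_accs =*/ NULL); // assumption: NULL means no separate accumulators

ggml_graph_reset(gb); // zero grads + optimizer momenta, set the loss grad to 1
```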

ggml/src/ggml-backend.cpp

Lines changed: 1 addition & 1 deletion
@@ -1111,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg

         const int node_backend_id = tensor_backend_id(node);

-        assert(node_backend_id != -1); // all nodes should be assigned by now
+        assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback

         // check if we should start a new split based on the sources of the current node
         bool need_new_split = false;

ggml/src/ggml-cuda/acc.cu

Lines changed: 40 additions & 26 deletions
@@ -1,47 +1,61 @@
 #include "acc.cuh"

-static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset) {
-    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+    const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
     if (i >= ne) {
         return;
     }
-    int src1_idx = i - offset;
-    int oz = src1_idx / nb2;
-    int oy = (src1_idx - (oz * nb2)) / nb1;
-    int ox = src1_idx % nb1;
-    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
-        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
-    } else {
-        dst[i] = x[i];
+
+    int64_t src1_idx = i - offset;
+
+    int64_t tmp = src1_idx;
+    const int64_t i13 = tmp / s13;
+    tmp -= i13 * s13;
+    const int64_t i12 = tmp / s12;
+    tmp -= i12 * s12;
+    const int64_t i11 = tmp / s11;
+    tmp -= i11 * s11;
+    const int64_t i10 = tmp;
+
+    float val = x[i];
+    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
+        val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
     }
+    dst[i] = val;
 }

-static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
-    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
-    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
+    const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+    const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
+    const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
 }

 void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported

-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
+    GGML_ASSERT(ggml_is_contiguously_allocated(dst));
+
+    const int64_t s1     = dst->op_params[0] / sizeof(float);
+    const int64_t s2     = dst->op_params[1] / sizeof(float);
+    const int64_t s3     = dst->op_params[2] / sizeof(float);
+    const int64_t offset = dst->op_params[3] / sizeof(float);

-    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
+    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
 }
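
The refactored kernel recovers the four src1 coordinates from a flat dst element index by repeated division with the dst strides (expressed in elements, not bytes). A small host-side sketch of the same arithmetic can help sanity-check the mapping; the function name and any values passed to it are illustrative only.

```c
#include <stdint.h>
#include <stdio.h>

// Host-side replica of the index decomposition in acc_f32 for one flat index i.
// s1/s2/s3 are the dst strides in elements, offset is the view offset in elements.
static void acc_decompose(int64_t i, int64_t offset, int64_t s1, int64_t s2, int64_t s3) {
    int64_t tmp = i - offset;               // position relative to the start of the view
    const int64_t i3 = tmp / s3; tmp -= i3 * s3;
    const int64_t i2 = tmp / s2; tmp -= i2 * s2;
    const int64_t i1 = tmp / s1; tmp -= i1 * s1;
    const int64_t i0 = tmp;                 // remainder is the innermost coordinate
    printf("i=%lld -> (i0=%lld, i1=%lld, i2=%lld, i3=%lld)\n",
           (long long) i, (long long) i0, (long long) i1, (long long) i2, (long long) i3);
}
```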

ggml/src/ggml-cuda/sum.cu

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguously_allocated(src0));

     const float * src0_d = (const float *) src0->data;
     float       * dst_d  = (float       *)  dst->data;
