#include <cstddef>
#include <cstdint>
#include <map>
#include <random>
#include <vector>
1616struct ggml_opt_dataset {
17- struct ggml_context * ctx;
18- ggml_backend_buffer_t buf;
19- struct ggml_tensor * data;
20- struct ggml_tensor * labels;
17+ struct ggml_context * ctx = nullptr ;
18+ ggml_backend_buffer_t buf = nullptr ;
19+ struct ggml_tensor * data = nullptr ;
20+ struct ggml_tensor * labels = nullptr ;
2121
22- int64_t ndata;
23- int64_t ndata_shard;
24- size_t nbs_data;
25- size_t nbs_labels;
22+ int64_t ndata = - 1 ;
23+ int64_t ndata_shard = - 1 ;
24+ size_t nbs_data = - 1 ;
25+ size_t nbs_labels = - 1 ;
2626
2727 std::vector<int64_t > permutation;
2828};
2929
3030struct ggml_opt_context {
31- ggml_backend_sched_t backend_sched;
32- ggml_cgraph * allocated_graph;
33- ggml_cgraph * allocated_graph_copy;
34- struct ggml_context * ctx_static;
35- struct ggml_context * ctx_static_cpu;
36- struct ggml_context * ctx_compute;
37- struct ggml_context * ctx_copy;
38- ggml_backend_buffer_t buf_static;
39- ggml_backend_buffer_t buf_static_cpu;
31+ ggml_backend_sched_t backend_sched = nullptr ;
32+ ggml_cgraph * allocated_graph = nullptr ;
33+ ggml_cgraph * allocated_graph_copy = nullptr ;
34+ struct ggml_context * ctx_static = nullptr ;
35+ struct ggml_context * ctx_static_cpu = nullptr ;
36+ struct ggml_context * ctx_compute = nullptr ;
37+ struct ggml_context * ctx_copy = nullptr ;
38+ ggml_backend_buffer_t buf_static = nullptr ;
39+ ggml_backend_buffer_t buf_static_cpu = nullptr ;
4040 std::mt19937 rng;
4141
42- struct ggml_tensor * inputs;
43- struct ggml_tensor * outputs;
44- struct ggml_tensor * labels;
42+ struct ggml_tensor * inputs = nullptr ;
43+ struct ggml_tensor * outputs = nullptr ;
44+ struct ggml_tensor * labels = nullptr ;
4545
46- struct ggml_tensor * loss;
47- struct ggml_tensor * pred;
48- struct ggml_tensor * ncorrect;
46+ struct ggml_tensor * loss = nullptr ;
47+ struct ggml_tensor * pred = nullptr ;
48+ struct ggml_tensor * ncorrect = nullptr ;
4949
50- struct ggml_cgraph * gf;
51- struct ggml_cgraph * gb_grad;
52- struct ggml_cgraph * gb_opt;
50+ struct ggml_cgraph * gf = nullptr ;
51+ struct ggml_cgraph * gb_grad = nullptr ;
52+ struct ggml_cgraph * gb_opt = nullptr ;
5353
54- int64_t iter;
55- int32_t opt_period;
56- int32_t opt_i;
57- bool loss_per_datapoint;
54+ int64_t iter = 1 ;
55+ int32_t opt_period = 1 ;
56+ int32_t opt_i = 0 ;
57+ bool loss_per_datapoint = false ;
5858
59- ggml_opt_get_optimizer_params get_opt_pars;
60- void * get_opt_pars_ud;
61- struct ggml_tensor * adamw_params;
59+ ggml_opt_get_optimizer_params get_opt_pars = nullptr ;
60+ void * get_opt_pars_ud = nullptr ;
61+ struct ggml_tensor * adamw_params = nullptr ;
6262};
6363
6464struct ggml_opt_result {
@@ -67,8 +67,8 @@ struct ggml_opt_result {
6767 std::vector<int32_t > pred;
6868 int64_t ncorrect = 0 ;
6969
70- bool loss_per_datapoint = false ;
71- int64_t opt_period = - 1 ;
70+ int64_t opt_period = - 1 ;
71+ bool loss_per_datapoint = false ;
7272};
7373
7474// ====== Dataset ======
@@ -188,11 +188,11 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
188188}
189189
190190struct ggml_opt_params ggml_opt_default_params (
191- ggml_backend_sched_t backend_sched,
192- struct ggml_context * ctx_compute,
193- struct ggml_tensor * inputs,
194- struct ggml_tensor * outputs,
195- enum ggml_opt_loss_type loss_type) {
191+ ggml_backend_sched_t backend_sched,
192+ struct ggml_context * ctx_compute,
193+ struct ggml_tensor * inputs,
194+ struct ggml_tensor * outputs,
195+ enum ggml_opt_loss_type loss_type) {
196196 return {
197197 /* backend_sched =*/ backend_sched,
198198 /* ctx_compute =*/ ctx_compute,
@@ -237,25 +237,33 @@ static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_
237237 return new_tensor;
238238}
239239
240- static ggml_cgraph * dup_graph (ggml_context * ctx, ggml_cgraph * graph ) {
240+ static ggml_cgraph * dup_graph (ggml_context * ctx, ggml_cgraph * src ) {
241241 std::map<ggml_tensor *, ggml_tensor *> tensor_map;
242242
243- ggml_cgraph * new_graph = ggml_new_graph_custom (ctx, GGML_DEFAULT_GRAPH_SIZE , /* grads =*/ true );
243+ ggml_cgraph * dst = ggml_new_graph_custom (ctx, src-> size , /* grads =*/ true );
244244
245- for (int i = 0 ; i < graph ->n_leafs ; i++) {
246- ggml_build_forward_expand (new_graph , map_tensor (tensor_map, ctx, graph ->leafs [i]));
245+ for (int i = 0 ; i < src ->n_leafs ; i++) {
246+ ggml_build_forward_expand (dst , map_tensor (tensor_map, ctx, src ->leafs [i]));
247247 }
248- for (int i = 0 ; i < graph->n_nodes ; i++) {
249- ggml_build_forward_expand (new_graph, map_tensor (tensor_map, ctx, graph->nodes [i]));
248+ GGML_ASSERT (dst->n_leafs == src->n_leafs );
249+ for (int i = 0 ; i < src->n_nodes ; i++) {
250+ ggml_build_forward_expand (dst, map_tensor (tensor_map, ctx, src->nodes [i]));
250251 }
251- for (int i = 0 ; i < graph->n_nodes ; ++i) {
252- const size_t igrad_src = ggml_hash_find (&graph->visited_hash_set , graph->nodes [i]);
253- const size_t igrad_dst = ggml_hash_find (&new_graph->visited_hash_set , new_graph->nodes [i]);
254- graph->grads [igrad_dst] = new_graph->grads [igrad_src];
255- graph->grad_accs [igrad_dst] = new_graph->grad_accs [igrad_src];
252+ GGML_ASSERT (dst->n_nodes == src->n_nodes );
253+ for (int i = 0 ; i < src->n_nodes ; ++i) {
254+ const size_t igrad_src = ggml_hash_find (&src->visited_hash_set , src->nodes [i]);
255+ const size_t igrad_dst = ggml_hash_find (&dst->visited_hash_set , dst->nodes [i]);
256+
257+ GGML_ASSERT (igrad_src != GGML_HASHSET_FULL);
258+ GGML_ASSERT (ggml_bitset_get (src->visited_hash_set .used , igrad_src));
259+ GGML_ASSERT (igrad_dst != GGML_HASHSET_FULL);
260+ GGML_ASSERT (ggml_bitset_get (dst->visited_hash_set .used , igrad_dst));
261+
262+ dst->grads [igrad_dst] = src->grads [igrad_src];
263+ dst->grad_accs [igrad_dst] = src->grad_accs [igrad_src];
256264 }
257265
258- return new_graph ;
266+ return dst ;
259267}
260268
261269static void ggml_opt_alloc_graph (ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
@@ -284,18 +292,13 @@ static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph
284292
285293ggml_opt_context_t ggml_opt_init (struct ggml_opt_params params) {
286294 ggml_opt_context_t result = new struct ggml_opt_context ;
287- result->backend_sched = params.backend_sched ;
288- result->allocated_graph = nullptr ;
289- result->allocated_graph_copy = nullptr ;
290- result->ctx_compute = params.ctx_compute ;
291- result->ctx_copy = nullptr ;
292- result->inputs = params.inputs ;
293- result->outputs = params.outputs ;
294- result->iter = 1 ;
295- result->opt_period = params.opt_period ;
296- result->opt_i = 0 ;
297- result->get_opt_pars = params.get_opt_pars ;
298- result->get_opt_pars_ud = params.get_opt_pars_ud ;
295+ result->backend_sched = params.backend_sched ;
296+ result->ctx_compute = params.ctx_compute ;
297+ result->inputs = params.inputs ;
298+ result->outputs = params.outputs ;
299+ result->opt_period = params.opt_period ;
300+ result->get_opt_pars = params.get_opt_pars ;
301+ result->get_opt_pars_ud = params.get_opt_pars_ud ;
299302
300303 GGML_ASSERT (result->inputs ->data && " the inputs must be allocated statically" );
301304 GGML_ASSERT (result->opt_period >= 1 );
@@ -348,7 +351,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
348351
349352 switch (params.loss_type ) {
350353 case GGML_OPT_LOSS_TYPE_MEAN: {
351- result->labels = nullptr ;
352354 result->loss = ggml_sum (result->ctx_static , result->outputs );
353355 ggml_set_name (result->loss , " loss_sum" );
354356 const float scale = 1 .0f / (result->opt_period * ggml_nelements (result->outputs ));
@@ -358,7 +360,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
358360 break ;
359361 }
360362 case GGML_OPT_LOSS_TYPE_SUM: {
361- result->labels = nullptr ;
362363 result->loss = ggml_sum (result->ctx_static , result->outputs );
363364 ggml_set_name (result->loss , " loss_sum" );
364365 result->loss_per_datapoint = false ;
@@ -413,14 +414,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
413414 }
414415
415416 if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
416- result->gb_grad = nullptr ;
417- result->gb_opt = nullptr ;
418-
419417 result->buf_static = ggml_backend_alloc_ctx_tensors (result->ctx_static , ggml_backend_sched_get_backend (result->backend_sched , 0 ));
420- result->buf_static_cpu = nullptr ;
421-
422- ggml_opt_alloc_graph (result, result->gf );
423-
424418 return result;
425419 }
426420
@@ -429,14 +423,8 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
429423 ggml_build_backward_expand (result->ctx_static , result->ctx_compute , result->gb_grad , accumulate);
430424
431425 if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
432- result->gb_opt = nullptr ;
433-
434426 result->buf_static = ggml_backend_alloc_ctx_tensors (result->ctx_static , ggml_backend_sched_get_backend (result->backend_sched , 0 ));
435- result->buf_static_cpu = nullptr ;
436-
437- ggml_opt_alloc_graph (result, result->gb_grad );
438427 ggml_graph_reset (result->gb_grad );
439-
440428 return result;
441429 }
442430
@@ -466,7 +454,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
466454
467455 result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft (result->ctx_static_cpu , ggml_backend_cpu_buffer_type ());
468456
469- ggml_opt_alloc_graph (result, result->gb_opt );
470457 ggml_graph_reset (result->gb_opt );
471458
472459 return result;
0 commit comments