@@ -69,9 +69,9 @@ extern "C" {
     // ====== Model / Context ======

     enum ggml_opt_build_type {
-        GGML_OPT_BUILD_TYPE_FORWARD,
-        GGML_OPT_BUILD_TYPE_GRAD,
-        GGML_OPT_BUILD_TYPE_OPT,
+        GGML_OPT_BUILD_TYPE_FORWARD = 10,
+        GGML_OPT_BUILD_TYPE_GRAD    = 20,
+        GGML_OPT_BUILD_TYPE_OPT     = 30,
     };

     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
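The explicit, spaced-out values make the build types ordered, which presumably allows calling code to use range comparisons (and leaves room for intermediate values later). A minimal sketch of that usage; `needs_gradients` is a hypothetical helper, not part of this patch:

```c
// hypothetical helper illustrating the ordered enum values;
// GRAD and OPT both require gradients, so a range check suffices
static bool needs_gradients(enum ggml_opt_build_type build_type) {
    return build_type >= GGML_OPT_BUILD_TYPE_GRAD;
}
```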
@@ -101,13 +101,11 @@ extern "C" {
     struct ggml_opt_params {
         ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs

-        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
-        // the forward graph is defined by inputs and outputs
-        // the outputs and all tensors between inputs and outputs that have not been statically allocated
-        //     are not intended to be reusable between multiple optimization contexts
-        struct ggml_tensor * inputs;
-        struct ggml_tensor * outputs;
+        // by default the forward graph needs to be reconstructed for each eval
+        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+        struct ggml_context * ctx_compute;
+        struct ggml_tensor  * inputs;
+        struct ggml_tensor  * outputs;

         enum ggml_opt_loss_type  loss_type;
         enum ggml_opt_build_type build_type;
@@ -121,11 +119,8 @@ extern "C" {
     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
     GGML_API struct ggml_opt_params ggml_opt_default_params(
-            ggml_backend_sched_t      backend_sched,
-            struct ggml_context     * ctx_compute,
-            struct ggml_tensor      * inputs,
-            struct ggml_tensor      * outputs,
-            enum ggml_opt_loss_type   loss_type);
+            ggml_backend_sched_t    backend_sched,
+            enum ggml_opt_loss_type loss_type);

     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
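With the slimmer signature, context setup might look like the sketch below; `sched`, `ctx`, `inputs`, and `outputs` are assumed to exist in user code, and setting the three static-graph fields is optional:

```c
// assumes a valid ggml_backend_sched_t sched created by the user
struct ggml_opt_params params = ggml_opt_default_params(sched, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);

// optional: opt in to statically allocated graphs by setting all three fields
params.ctx_compute = ctx;     // user-created context holding the forward graph tensors
params.inputs      = inputs;  // forward graph input tensor
params.outputs     = outputs; // forward graph output tensor

ggml_opt_context_t opt_ctx = ggml_opt_init(params);
```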
@@ -134,13 +129,15 @@ extern "C" {
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);

     // get underlying tensors that store data
+    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
     GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
     GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
     GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
     GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
     GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels

+    // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);

     // ====== Optimization Result ======
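As a usage note for the accessors: the returned tensors live on the backend, so their data has to be read back explicitly. A sketch, assuming `opt_ctx` has just been evaluated:

```c
// read the scalar loss back from the backend after an eval;
// without static graphs this pointer is only valid until the next ggml_opt_alloc
struct ggml_tensor * loss = ggml_opt_loss(opt_ctx);
float loss_value;
ggml_backend_tensor_get(loss, &loss_value, 0, sizeof(loss_value));
```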
@@ -157,15 +154,20 @@ extern "C" {

     // ====== Computation ======

-    GGML_API void ggml_opt_set_forward_graph(
-        ggml_opt_context_t opt_ctx, struct ggml_context * ctx_compute, struct ggml_cgraph * gf,
-        struct ggml_tensor * inputs, struct ggml_tensor * outputs, bool backward);
+    // if not using static graphs, this function must be called prior to ggml_opt_alloc
+    GGML_API void ggml_opt_prepare_alloc(
+        ggml_opt_context_t    opt_ctx,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * gf,
+        struct ggml_tensor  * inputs,
+        struct ggml_tensor  * outputs);

-    // do forward pass, increment result if not NULL
-    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // allocate the next graph for evaluation, either forward or forward + backward
+    // must be called exactly once prior to calling ggml_opt_eval
+    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);

-    // do forward pass, increment result if not NULL, do backward pass
-    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // do forward pass, increment result if not NULL, do backward pass if allocated
+    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);

     // ############################################################################
     // ## The high-level functions start here. They do not depend on any private ##
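Put together, the prepare/alloc/eval split yields a per-step flow like the sketch below; `build_forward` is a hypothetical user function that reconstructs the forward graph in `ctx_compute` each step, and `result` may be NULL:

```c
for (int step = 0; step < n_steps; ++step) {
    // reconstruct the forward graph for this eval (hypothetical user code)
    struct ggml_tensor * inputs;
    struct ggml_tensor * outputs;
    struct ggml_cgraph * gf = build_forward(ctx_compute, &inputs, &outputs);

    ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, inputs, outputs); // required without static graphs
    ggml_opt_alloc(opt_ctx, /*backward =*/ true);                      // exactly once per eval

    // set input (and label) data here, e.g. via ggml_backend_tensor_set
    // on ggml_opt_inputs(opt_ctx) and ggml_opt_labels(opt_ctx)

    ggml_opt_eval(opt_ctx, result); // forward + backward, since backward was allocated
}
```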