@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======
 
     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,    // the type for the internal data tensor
+            enum ggml_type type_label,   // the type for the internal labels tensor
+            int64_t        ne_datapoint, // number of elements per datapoint
+            int64_t        ne_label,     // number of elements per label
+            int64_t        ndata,        // total number of datapoints/labels
+            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
 
     // get underlying tensors that store the data
+    GGML_API int64_t              ggml_opt_dataset_ndata  (ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label,     ndata]
 
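For reference, a minimal sketch of how the extended initializer and the new ndata getter might be used; GGML_TYPE_F32 is a real ggml type, but the dimensions and variable names below are illustrative and not taken from this commit:

    // sketch: a dataset holding float datapoints and float labels
    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
            GGML_TYPE_F32, // type_data:  internal data tensor stores 32-bit floats
            GGML_TYPE_F32, // type_label: internal labels tensor stores 32-bit floats
            28*28,         // ne_datapoint (illustrative)
            10,            // ne_label     (illustrative)
            60000,         // ndata        (illustrative)
            1000);         // ndata_shard  (illustrative)

    int64_t ndata = ggml_opt_dataset_ndata(dataset); // 60000
    struct ggml_tensor * data   = ggml_opt_dataset_data  (dataset); // shape = [ne_datapoint, ndata]
    struct ggml_tensor * labels = ggml_opt_dataset_labels(dataset); // shape = [ne_label,     ndata]

    // ... fill data/labels and train, then release the dataset:
    ggml_opt_dataset_free(dataset);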
@@ -56,6 +59,12 @@ extern "C" {
             struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor * labels_batch, // shape = [ne_label,     ndata_batch]
             int64_t              ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t   dataset,
+            void               * data_batch,
+            size_t               nb_data_batch,
+            void               * labels_batch,
+            int64_t              ibatch);
 
     // ====== Model / Context ======
 
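A sketch of using the host-side batch getter declared above; it reuses the illustrative dataset from the previous sketch, requires <stdlib.h>, and assumes that nb_data_batch is the size in bytes of the data_batch buffer (an assumption, not stated in this commit):

    // sketch: copy logical batch 0 into plain host buffers
    int64_t ndata_batch = 32; // illustrative batch size
    float * data_host   = malloc(ndata_batch * 28*28 * sizeof(float));
    float * labels_host = malloc(ndata_batch * 10    * sizeof(float));

    ggml_opt_dataset_get_batch_host(
            dataset,
            data_host, ndata_batch * 28*28 * sizeof(float), // assumption: byte size of data_host
            labels_host,
            /*ibatch =*/ 0);

    free(labels_host);
    free(data_host);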
@@ -92,7 +101,8 @@ extern "C" {
         struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
 
         // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
+        // the outputs and all tensors between inputs and outputs that have not been statically allocated
+        //     are not intended to be reusable between multiple optimization contexts
         struct ggml_tensor * inputs;
         struct ggml_tensor * outputs;
 
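To make the reworded comment concrete, a sketch of the split it describes; ctx_static, ne_in, ne_out and nbatch are hypothetical names, and the intent is inferred from the comment: statically allocated tensors (such as the inputs and weights) may be reused, while the outputs and intermediate tensors built in ctx_compute are not meant to outlive a single optimization context.

    // sketch: statically allocated tensors vs. tensors built in ctx_compute
    struct ggml_tensor * inputs  = ggml_new_tensor_2d(ctx_static,  GGML_TYPE_F32, ne_in, nbatch);
    struct ggml_tensor * weights = ggml_new_tensor_2d(ctx_static,  GGML_TYPE_F32, ne_in, ne_out);
    struct ggml_tensor * outputs = ggml_mul_mat(ctx_compute, weights, inputs); // not reusable across opt contexts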
@@ -107,7 +117,7 @@ extern "C" {
 
     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
             ggml_backend_sched_t      backend_sched,
             struct ggml_context     * ctx_compute,
             struct ggml_tensor      * inputs,
@@ -144,6 +154,10 @@ extern "C" {
 
     // ====== Computation ======
 
+    GGML_API void ggml_opt_set_forward_graph(
+        ggml_opt_context_t opt_ctx, struct ggml_context * ctx_compute, struct ggml_cgraph * gf,
+        struct ggml_tensor * inputs, struct ggml_tensor * outputs, bool backward);
+
     // do forward pass, increment result if not NULL
     GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
 
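A sketch of how the new graph setter might be paired with the existing forward pass; opt_ctx, ctx_compute, gf, inputs and outputs are assumed to already exist in user code, and the meaning of the backward flag (whether to also build the backward graph) is inferred from its name rather than stated in this commit:

    // sketch: swap in a new forward graph, then run it without accumulating a result
    ggml_opt_set_forward_graph(opt_ctx, ctx_compute, gf, inputs, outputs, /*backward =*/ false);
    ggml_opt_forward(opt_ctx, /*result =*/ NULL); // "increment result if not NULL", per the comment above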
@@ -200,9 +214,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
             ggml_backend_sched_t            backend_sched,  // backend scheduler for constructing the compute graphs
-            ggml_context                  * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs
-            ggml_tensor                   * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]
-            ggml_tensor                   * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            struct ggml_context           * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs
+            struct ggml_tensor            * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]
+            struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)