@@ -2105,46 +2105,105 @@ void llama_context::opt_epoch_iter(
21052105}
21062106
21072107void llama_context::opt_epoch (
2108- ggml_opt_dataset_t dataset,
2109- ggml_opt_result_t result_train,
2110- ggml_opt_result_t result_eval,
2111- int64_t idata_split,
2112- ggml_opt_epoch_callback callback_train,
2113- ggml_opt_epoch_callback callback_eval) {
2108+ ggml_opt_dataset_t dataset,
2109+ ggml_opt_result_t result_train,
2110+ ggml_opt_result_t result_eval,
2111+ int64_t idata_split,
2112+ ggml_opt_epoch_callback callback_train,
2113+ ggml_opt_epoch_callback callback_eval) {
21142114 const uint32_t n_ctx = this ->n_ctx ();
2115- const uint32_t n_batch = std::min (cparams.n_batch , n_ctx);
2115+ const uint32_t n_batch = std::min (cparams.n_batch , n_ctx);
21162116 const uint32_t n_ubatch = std::min (cparams.n_ubatch , n_batch);
2117- const int64_t ndata = ggml_opt_dataset_ndata (dataset);
2117+ const int64_t ndata = ggml_opt_dataset_ndata (dataset);
21182118
21192119 GGML_ASSERT (idata_split >= 0 );
21202120 GGML_ASSERT (idata_split <= ndata);
21212121
21222122 const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
21232123
21242124 struct llama_batch batch = llama_batch_init (n_batch, 0 , 1 );
2125- std::vector<llama_token> tokens (n_ctx);
2125+ std::vector<llama_token> tokens (n_ctx);
21262126 std::vector<llama_token> labels_sparse (n_ctx);
21272127
2128- int64_t idata = 0 ;
2128+ // Ensure batch is cleared
2129+ batch.n_tokens = 0 ;
21292130
21302131 int64_t t_loop_start = ggml_time_us ();
2131- int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
2132- for (; idata < idata_split; ++idata) {
2132+ int64_t ndata_in_loop = idata_split * ubatch_per_ctx;
2133+
2134+ fprintf (stderr, " Starting training loop: idata_split = %ld, ndata = %ld, n_ctx = %u, n_batch = %u, n_ubatch = %u\n " ,
2135+ idata_split, ndata, n_ctx, n_batch, n_ubatch);
2136+
2137+ for (int64_t idata = 0 ; idata < idata_split; ++idata) {
21332138 constexpr bool train = true ;
2134- const int64_t idata_in_loop = idata*ubatch_per_ctx;
2139+ const int64_t idata_in_loop = idata * ubatch_per_ctx;
2140+
2141+ fprintf (stderr, " Training: idata = %ld, idata_in_loop = %ld\n " , idata, idata_in_loop);
21352142
2143+ // Clear vectors
2144+ std::fill (tokens.begin (), tokens.end (), 0 );
2145+ std::fill (labels_sparse.begin (), labels_sparse.end (), 0 );
2146+
2147+ // Retrieve batch with correct size
21362148 ggml_opt_dataset_get_batch_host (dataset, tokens.data (), n_ctx*sizeof (llama_token), labels_sparse.data (), idata);
2149+
2150+ fprintf (stderr, " Batch retrieved for training: idata = %ld\n " , idata);
2151+
2152+ // Populate batch
2153+ batch.n_tokens = 0 ;
2154+ for (uint32_t i = 0 ; i < 511 && batch.n_tokens < n_batch; ++i) {
2155+ batch.token [batch.n_tokens ] = tokens[i];
2156+ batch.pos [batch.n_tokens ] = i;
2157+ batch.seq_id [batch.n_tokens ] = 0 ;
2158+ batch.n_tokens ++;
2159+ }
2160+ // Add label (assuming single-token label)
2161+ if (batch.n_tokens < n_batch) {
2162+ batch.token [batch.n_tokens ] = labels_sparse[0 ];
2163+ batch.pos [batch.n_tokens ] = 511 ;
2164+ batch.seq_id [batch.n_tokens ] = 0 ;
2165+ batch.n_tokens ++;
2166+ }
2167+
21372168 opt_epoch_iter (dataset, result_train, tokens, labels_sparse, batch,
21382169 callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
21392170 }
21402171
21412172 t_loop_start = ggml_time_us ();
2142- ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
2143- for (; idata < ndata; ++idata) {
2173+ ndata_in_loop = (ndata - idata_split) * ubatch_per_ctx;
2174+
2175+ fprintf (stderr, " Starting validation loop: idata = %ld, ndata = %ld\n " , idata_split, ndata);
2176+
2177+ for (int64_t idata = idata_split; idata < ndata; ++idata) {
21442178 constexpr bool train = false ;
2145- const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
2179+ const int64_t idata_in_loop = (idata - idata_split) * ubatch_per_ctx;
2180+
2181+ fprintf (stderr, " Validation: idata = %ld, idata_in_loop = %ld\n " , idata, idata_in_loop);
2182+
2183+ // Clear vectors
2184+ std::fill (tokens.begin (), tokens.end (), 0 );
2185+ std::fill (labels_sparse.begin (), labels_sparse.end (), 0 );
21462186
21472187 ggml_opt_dataset_get_batch_host (dataset, tokens.data (), n_ctx*sizeof (llama_token), labels_sparse.data (), idata);
2188+
2189+ fprintf (stderr, " Batch retrieved for validation: idata = %ld\n " , idata);
2190+
2191+ // Populate batch
2192+ batch.n_tokens = 0 ;
2193+ for (uint32_t i = 0 ; i < 511 && batch.n_tokens < n_batch; ++i) {
2194+ batch.token [batch.n_tokens ] = tokens[i];
2195+ batch.pos [batch.n_tokens ] = i;
2196+ batch.seq_id [batch.n_tokens ] = 0 ;
2197+ batch.n_tokens ++;
2198+ }
2199+ // Add label
2200+ if (batch.n_tokens < n_batch) {
2201+ batch.token [batch.n_tokens ] = labels_sparse[0 ];
2202+ batch.pos [batch.n_tokens ] = 511 ;
2203+ batch.seq_id [batch.n_tokens ] = 0 ;
2204+ batch.n_tokens ++;
2205+ }
2206+
21482207 opt_epoch_iter (dataset, result_eval, tokens, labels_sparse, batch,
21492208 callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
21502209 }
0 commit comments