@@ -1954,6 +1954,17 @@ void llama_context::opt_epoch_iter(
             // }
             llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
+            n_outputs = ubatch.n_tokens;
+
+            printf("ubatch.n_tokens = %d\n", ubatch.n_tokens);
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
             auto * gf = graph_init();
             auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
 
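For context, the lines added in this hunk run once per micro-batch inside the training loop. A condensed sketch of that flow (variable names as in the hunk; the comments on intent are my reading, not the author's):

    // Per-ubatch flow in llama_context::opt_epoch_iter (sketch).
    while (sbatch.n_tokens > 0) {
        // Carve the next micro-batch of at most cparams.n_ubatch tokens off the sequence batch.
        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);

        // During training every token produces an output (the loss is computed over all of them).
        n_outputs = ubatch.n_tokens;

        // Reserve KV cache cells before graph_build(), since the attention graph is
        // constructed against the slot chosen here; on failure the code currently
        // aborts rather than recovering (see the TODOs in the hunk).
        if (!kv_self->find_slot(ubatch)) {
            GGML_ABORT("TODO: handle this error");
        }

        auto * gf = graph_init();
        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
        // ... optimizer allocation and evaluation follow (second hunk).
    }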
@@ -1969,7 +1980,7 @@ void llama_context::opt_epoch_iter(
                 };
                 ctx_compute_opt = ggml_init(params);
             }
-            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), ggml_graph_node(gf, -1));
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
             ggml_opt_alloc(opt_ctx, train);
             // llama_set_inputs(*lctx, ubatch);
             res->set_inputs(&ubatch);
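The one-line change in this hunk swaps the outputs argument of ggml_opt_prepare_alloc from ggml_graph_node(gf, -1), i.e. whatever tensor happens to be the last node of the forward graph, to res->get_logits(), the logits tensor that the graph-build result exposes directly. A sketch of the call sequence with the arguments spelled out (ggml_opt_prepare_alloc and ggml_opt_alloc as declared in ggml-opt.h; the comments are mine):

    ggml_opt_prepare_alloc(
        opt_ctx,           // optimizer context holding the training state
        ctx_compute_opt,   // ggml_context the backward graph is built in
        gf,                // forward graph for this ubatch
        res->get_tokens(), // input tensor: the token ids
        res->get_logits()  // output tensor the loss attaches to; taken explicitly
                           // from res, no longer assumed to be the last graph node
    );
    ggml_opt_alloc(opt_ctx, /*backward =*/ train); // allocate fwd (+bwd when training)
    res->set_inputs(&ubatch);                      // copy ubatch data into the input tensors

The switch matters because the last node of the graph is not guaranteed to stay the logits once trailing nodes are appended (e.g. for debugging or pooling); taking the tensor from res makes the dependency explicit.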