@@ -721,15 +721,17 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
     return res;
 }
 
-int llama_context::encode(llama_batch & inp_batch) {
-    if (inp_batch.n_tokens == 0) {
+int llama_context::encode(const llama_batch & batch_inp) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
     // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0);
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+        return -1;
+    }
 
     const llama_batch & batch = batch_allocr->get_batch();
 
@@ -739,21 +741,6 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                throw -1;
-            }
-        }
-    }
-
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
 
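The token-range and seq_id checks removed above presumably now live inside `llama_batch_allocr::init`, which this change extends to take the model vocab and report failure to the caller. A minimal sketch of what that relocated validation could look like; the exact signature, member layout, and position handling are assumptions, not taken from this diff:

```cpp
// Hypothetical sketch: batch validation folded into the allocator.
// Everything except llama_batch_allocr::init, llama_vocab::n_tokens()
// and LLAMA_MAX_PARALLEL_SEQUENCES is an illustrative assumption.
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
    // ... existing temporary-allocation and pos handling using p0 ...

    if (batch_inp.token) {
        for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
            if (batch_inp.token[i] < 0 || (uint32_t) batch_inp.token[i] >= vocab.n_tokens()) {
                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_inp.token[i]);
                return false; // encode()/decode() translate this into -1
            }

            if (batch_inp.seq_id && (batch_inp.seq_id[i][0] < 0 || batch_inp.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch_inp.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
                return false;
            }
        }
    }

    return true;
}
```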
@@ -897,26 +884,28 @@ int llama_context::encode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context::decode(llama_batch & inp_batch) {
+int llama_context::decode(const llama_batch & batch_inp) {
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
-        return encode(inp_batch);
+        return encode(batch_inp);
     }
 
-    if (inp_batch.n_tokens == 0) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    if (!inp_batch.pos) {
-        if (inp_batch.seq_id) {
+    if (!batch_inp.pos) {
+        if (batch_inp.seq_id) {
             LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
             return -1;
         }
     }
 
     // temporary allocate memory for the input batch if needed
-    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
+    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+        return -1;
+    }
 
     const llama_batch & batch = batch_allocr->get_batch();
 
@@ -930,21 +919,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return -1;
-            }
-        }
-    }
-
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
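From the caller's point of view, the net effect is that a malformed batch now surfaces as an error return from encode/decode (instead of, in the old encode path, a `throw -1`). A brief usage sketch of the public API under the usual llama.cpp convention that 0 means success and a negative value means the batch could not be processed; this is commentary, not part of the diff:

```cpp
// Check the result of llama_decode() rather than assuming the batch is valid:
// invalid tokens or seq_ids rejected by the allocator now show up here as -1.
if (llama_decode(ctx, batch) < 0) {
    fprintf(stderr, "llama_decode failed: invalid or unprocessable batch\n");
    // handle the error (e.g. drop the request) instead of aborting
}
```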