@@ -693,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) {
693693
694694 GGML_ASSERT ((!batch.token && batch.embd ) || (batch.token && !batch.embd )); // NOLINT
695695
696+ // TODO: move the validation to the llama_batch_allocr
696697 if (batch.token ) {
697698 for (int32_t i = 0 ; i < n_tokens; ++i) {
698699 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= model.vocab .n_tokens ()) {
699700 LLAMA_LOG_ERROR (" %s: invalid token[%d] = %d\n " , __func__, i, batch.token [i]);
700701 return -1 ;
701702 }
703+
704+ if (batch.seq_id && (batch.seq_id [i][0 ] < 0 || batch.seq_id [i][0 ] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
705+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%d] = %d > %d\n " , __func__, i, batch.seq_id [i][0 ], LLAMA_MAX_PARALLEL_SEQUENCES);
706+ throw -1 ;
707+ }
702708 }
703709 }
704710
@@ -887,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) {
887893
888894 GGML_ASSERT ((!batch.token && batch.embd ) || (batch.token && !batch.embd )); // NOLINT
889895
896+ // TODO: move the validation to the llama_batch_allocr
890897 if (batch.token ) {
891898 for (int64_t i = 0 ; i < n_tokens_all; ++i) {
892899 if (batch.token [i] < 0 || (uint32_t ) batch.token [i] >= model.vocab .n_tokens ()) {
893900 LLAMA_LOG_ERROR (" %s: invalid token[%" PRId64 " ] = %d\n " , __func__, i, batch.token [i]);
894- throw std::runtime_error (" invalid token" );
901+ return -1 ;
902+ }
903+
904+ if (batch.seq_id && (batch.seq_id [i][0 ] < 0 || batch.seq_id [i][0 ] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
905+ LLAMA_LOG_ERROR (" %s: invalid seq_id[%" PRId64 " ] = %d >= %d\n " , __func__, i, batch.seq_id [i][0 ], LLAMA_MAX_PARALLEL_SEQUENCES);
906+ return -1 ;
895907 }
896908 }
897909 }
0 commit comments