@@ -842,19 +842,54 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
842842 };
843843
844844 if (embd) {
845+ if (n_tokens_alloc > 0 && embd > 0 && n_tokens_alloc > SIZE_MAX / (sizeof (float ) * embd)) {
846+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
847+ return batch;
848+ }
845849 batch.embd = (float *) malloc (sizeof (float ) * n_tokens_alloc * embd);
846850 } else {
851+ if (n_tokens_alloc > SIZE_MAX / sizeof (llama_token)) {
852+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
853+ return batch;
854+ }
847855 batch.token = (llama_token *) malloc (sizeof (llama_token) * n_tokens_alloc);
848856 }
849857
858+ if (n_tokens_alloc > SIZE_MAX / sizeof (llama_pos)) {
859+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
860+ llama_batch_free (batch);
861+ return batch;
862+ }
850863 batch.pos = (llama_pos *) malloc (sizeof (llama_pos) * n_tokens_alloc);
864+
865+ if (n_tokens_alloc > SIZE_MAX / sizeof (int32_t )) {
866+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
867+ llama_batch_free (batch);
868+ return batch;
869+ }
851870 batch.n_seq_id = (int32_t *) malloc (sizeof (int32_t ) * n_tokens_alloc);
871+
872+ if (n_tokens_alloc + 1 > SIZE_MAX / sizeof (llama_seq_id *)) {
873+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
874+ llama_batch_free (batch);
875+ return batch;
876+ }
852877 batch.seq_id = (llama_seq_id **) malloc (sizeof (llama_seq_id *) * (n_tokens_alloc + 1 ));
853878 for (int i = 0 ; i < n_tokens_alloc; ++i) {
879+ if (n_seq_max > SIZE_MAX / sizeof (llama_seq_id)) {
880+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
881+ llama_batch_free (batch);
882+ return batch;
883+ }
854884 batch.seq_id [i] = (llama_seq_id *) malloc (sizeof (llama_seq_id) * n_seq_max);
855885 }
856886 batch.seq_id [n_tokens_alloc] = nullptr ;
857887
888+ if (n_tokens_alloc > SIZE_MAX / sizeof (int8_t )) {
889+ LLAMA_LOG_ERROR (" %s: integer overflow in memory allocation\n " , __func__);
890+ llama_batch_free (batch);
891+ return batch;
892+ }
858893 batch.logits = (int8_t *) malloc (sizeof (int8_t ) * n_tokens_alloc);
859894
860895 return batch;
0 commit comments