Skip to content

Commit 495b1f5

Browse files
fix: add integer overflow checks for memory allocation in llama-batch.cpp
- Added overflow checks before malloc calls - Prevents integer overflow in batch initialization - Properly handles error cases with cleanup Addresses integer overflow vulnerabilities (CWE-190) Co-Authored-By: Jake Cosme <[email protected]>
1 parent a1be60e commit 495b1f5

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

src/llama-batch.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,19 +842,54 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
842842
};
843843

844844
if (embd) {
845+
if (n_tokens_alloc > 0 && embd > 0 && n_tokens_alloc > SIZE_MAX / (sizeof(float) * embd)) {
846+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
847+
return batch;
848+
}
845849
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
846850
} else {
851+
if (n_tokens_alloc > SIZE_MAX / sizeof(llama_token)) {
852+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
853+
return batch;
854+
}
847855
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
848856
}
849857

858+
if (n_tokens_alloc > SIZE_MAX / sizeof(llama_pos)) {
859+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
860+
llama_batch_free(batch);
861+
return batch;
862+
}
850863
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
864+
865+
if (n_tokens_alloc > SIZE_MAX / sizeof(int32_t)) {
866+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
867+
llama_batch_free(batch);
868+
return batch;
869+
}
851870
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
871+
872+
if (n_tokens_alloc + 1 > SIZE_MAX / sizeof(llama_seq_id *)) {
873+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
874+
llama_batch_free(batch);
875+
return batch;
876+
}
852877
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
853878
for (int i = 0; i < n_tokens_alloc; ++i) {
879+
if (n_seq_max > SIZE_MAX / sizeof(llama_seq_id)) {
880+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
881+
llama_batch_free(batch);
882+
return batch;
883+
}
854884
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
855885
}
856886
batch.seq_id[n_tokens_alloc] = nullptr;
857887

888+
if (n_tokens_alloc > SIZE_MAX / sizeof(int8_t)) {
889+
LLAMA_LOG_ERROR("%s: integer overflow in memory allocation\n", __func__);
890+
llama_batch_free(batch);
891+
return batch;
892+
}
858893
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
859894

860895
return batch;

0 commit comments

Comments
 (0)