Skip to content

Commit 6b065fc

Browse files
fix: add integer overflow checks for malloc in llama-android.cpp
Co-Authored-By: Jake Cosme <[email protected]>
1 parent 4852a8a commit 6b065fc

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,17 +286,76 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
286286
};
287287

288288
if (embd) {
289+
if (n_tokens > 0 && embd > 0 && (size_t)n_tokens > SIZE_MAX / sizeof(float) / (size_t)embd) {
290+
LOGe("integer overflow in embd allocation");
291+
delete batch;
292+
return 0;
293+
}
289294
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
290295
} else {
296+
if (n_tokens > 0 && (size_t)n_tokens > SIZE_MAX / sizeof(llama_token)) {
297+
LOGe("integer overflow in token allocation");
298+
delete batch;
299+
return 0;
300+
}
291301
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
292302
}
293303

304+
if (n_tokens > 0 && (size_t)n_tokens > SIZE_MAX / sizeof(llama_pos)) {
305+
LOGe("integer overflow in pos allocation");
306+
if (embd) free(batch->embd); else free(batch->token);
307+
delete batch;
308+
return 0;
309+
}
294310
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
311+
312+
if (n_tokens > 0 && (size_t)n_tokens > SIZE_MAX / sizeof(int32_t)) {
313+
LOGe("integer overflow in n_seq_id allocation");
314+
free(batch->pos);
315+
if (embd) free(batch->embd); else free(batch->token);
316+
delete batch;
317+
return 0;
318+
}
295319
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
320+
321+
if (n_tokens > 0 && (size_t)n_tokens > SIZE_MAX / sizeof(llama_seq_id *)) {
322+
LOGe("integer overflow in seq_id allocation");
323+
free(batch->n_seq_id);
324+
free(batch->pos);
325+
if (embd) free(batch->embd); else free(batch->token);
326+
delete batch;
327+
return 0;
328+
}
296329
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
330+
297331
for (int i = 0; i < n_tokens; ++i) {
332+
if (n_seq_max > 0 && (size_t)n_seq_max > SIZE_MAX / sizeof(llama_seq_id)) {
333+
LOGe("integer overflow in seq_id[%d] allocation", i);
334+
for (int j = 0; j < i; ++j) {
335+
free(batch->seq_id[j]);
336+
}
337+
free(batch->seq_id);
338+
free(batch->n_seq_id);
339+
free(batch->pos);
340+
if (embd) free(batch->embd); else free(batch->token);
341+
delete batch;
342+
return 0;
343+
}
298344
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
299345
}
346+
347+
if (n_tokens > 0 && (size_t)n_tokens > SIZE_MAX / sizeof(int8_t)) {
348+
LOGe("integer overflow in logits allocation");
349+
for (int i = 0; i < n_tokens; ++i) {
350+
free(batch->seq_id[i]);
351+
}
352+
free(batch->seq_id);
353+
free(batch->n_seq_id);
354+
free(batch->pos);
355+
if (embd) free(batch->embd); else free(batch->token);
356+
delete batch;
357+
return 0;
358+
}
300359
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
301360

302361
return reinterpret_cast<jlong>(batch);

0 commit comments

Comments
 (0)