@@ -286,17 +286,76 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
286286 };
287287
288288 if (embd) {
289+ if (n_tokens > 0 && embd > 0 && (size_t )n_tokens > SIZE_MAX / sizeof (float ) / (size_t )embd) {
290+ LOGe (" integer overflow in embd allocation" );
291+ delete batch;
292+ return 0 ;
293+ }
289294 batch->embd = (float *) malloc (sizeof (float ) * n_tokens * embd);
290295 } else {
296+ if (n_tokens > 0 && (size_t )n_tokens > SIZE_MAX / sizeof (llama_token)) {
297+ LOGe (" integer overflow in token allocation" );
298+ delete batch;
299+ return 0 ;
300+ }
291301 batch->token = (llama_token *) malloc (sizeof (llama_token) * n_tokens);
292302 }
293303
304+ if (n_tokens > 0 && (size_t )n_tokens > SIZE_MAX / sizeof (llama_pos)) {
305+ LOGe (" integer overflow in pos allocation" );
306+ if (embd) free (batch->embd ); else free (batch->token );
307+ delete batch;
308+ return 0 ;
309+ }
294310 batch->pos = (llama_pos *) malloc (sizeof (llama_pos) * n_tokens);
311+
312+ if (n_tokens > 0 && (size_t )n_tokens > SIZE_MAX / sizeof (int32_t )) {
313+ LOGe (" integer overflow in n_seq_id allocation" );
314+ free (batch->pos );
315+ if (embd) free (batch->embd ); else free (batch->token );
316+ delete batch;
317+ return 0 ;
318+ }
295319 batch->n_seq_id = (int32_t *) malloc (sizeof (int32_t ) * n_tokens);
320+
321+ if (n_tokens > 0 && (size_t )n_tokens > SIZE_MAX / sizeof (llama_seq_id *)) {
322+ LOGe (" integer overflow in seq_id allocation" );
323+ free (batch->n_seq_id );
324+ free (batch->pos );
325+ if (embd) free (batch->embd ); else free (batch->token );
326+ delete batch;
327+ return 0 ;
328+ }
296329 batch->seq_id = (llama_seq_id **) malloc (sizeof (llama_seq_id *) * n_tokens);
330+
297331 for (int i = 0 ; i < n_tokens; ++i) {
332+ if (n_seq_max > 0 && (size_t )n_seq_max > SIZE_MAX / sizeof (llama_seq_id)) {
333+ LOGe (" integer overflow in seq_id[%d] allocation" , i);
334+ for (int j = 0 ; j < i; ++j) {
335+ free (batch->seq_id [j]);
336+ }
337+ free (batch->seq_id );
338+ free (batch->n_seq_id );
339+ free (batch->pos );
340+ if (embd) free (batch->embd ); else free (batch->token );
341+ delete batch;
342+ return 0 ;
343+ }
298344 batch->seq_id [i] = (llama_seq_id *) malloc (sizeof (llama_seq_id) * n_seq_max);
299345 }
346+
347+ if (n_tokens > 0 && (size_t )n_tokens > SIZE_MAX / sizeof (int8_t )) {
348+ LOGe (" integer overflow in logits allocation" );
349+ for (int i = 0 ; i < n_tokens; ++i) {
350+ free (batch->seq_id [i]);
351+ }
352+ free (batch->seq_id );
353+ free (batch->n_seq_id );
354+ free (batch->pos );
355+ if (embd) free (batch->embd ); else free (batch->token );
356+ delete batch;
357+ return 0 ;
358+ }
300359 batch->logits = (int8_t *) malloc (sizeof (int8_t ) * n_tokens);
301360
302361 return reinterpret_cast <jlong>(batch);
0 commit comments