Skip to content

Commit 974f045

Browse files
committed
Fix DoS / integer overflow
1 parent 07b0e7a commit 974f045

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/llama-grammar.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <cmath>
88
#include <algorithm>
9+
#include <optional>
910
#include <stdexcept>
1011

1112
//
@@ -478,10 +479,10 @@ const char * llama_grammar_parser::parse_sequence(
478479
throw std::runtime_error(std::string("expecting an int at ") + pos);
479480
}
480481
const char * int_end = parse_int(pos);
481-
int min_times = std::stoul(std::string(pos, int_end - pos));
482+
unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
482483
pos = parse_space(int_end, is_nested);
483484

484-
int max_times = -1;
485+
std::optional<unsigned long> max_times = std::optional<unsigned long>();
485486

486487
if (*pos == '}') {
487488
max_times = min_times;
@@ -502,7 +503,10 @@ const char * llama_grammar_parser::parse_sequence(
502503
} else {
503504
throw std::runtime_error(std::string("expecting ',' at ") + pos);
504505
}
505-
handle_repetitions(min_times, max_times);
506+
if (min_times > MAX_REPETITION_THRESHOLD || (max_times.has_value() && max_times.value() > MAX_REPETITION_THRESHOLD)) {
507+
throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
508+
}
509+
handle_repetitions(min_times, max_times.value_or(-1));
506510
} else {
507511
break;
508512
}

src/llama-grammar.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <string>
88
#include <vector>
99

10+
#define MAX_REPETITION_THRESHOLD 5000
11+
1012
struct llama_vocab;
1113

1214
// grammar element type

0 commit comments

Comments
 (0)