
Commit d8485ad

Fix DOS index bug

Parent: 80a02aa

The commit introduces a bounds check for token ids (llama_vocab::impl::is_valid()), exposes it through a new can_detokenize() API and a can_be_detokenized() helper, and uses it in the server to reject prompts containing out-of-vocabulary token ids instead of letting them index the vocab table out of bounds.

6 files changed: 38 additions, 0 deletions

common/common.cpp

Lines changed: 6 additions & 0 deletions

@@ -1748,6 +1748,12 @@ std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token
     return piece;
 }
 
+bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    return can_detokenize(vocab, tokens.data(), (int32_t) tokens.size());
+}
+
 std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
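The helper is a thin convenience wrapper: it pulls the vocab out of the context and defers to the new can_detokenize() C API (declared in include/llama.h below). A minimal caller-side sketch, not part of the commit; the wrapper name is hypothetical and error handling is reduced to returning an empty string:

    #include "common.h"

    #include <string>
    #include <vector>

    // Validate ids before handing them to common_detokenize(), which
    // ultimately indexes the vocab's id_to_token table with each id.
    std::string detokenize_checked(const llama_context * ctx,
                                   const std::vector<llama_token> & tokens) {
        if (!can_be_detokenized(ctx, tokens)) {
            return {}; // at least one id is out of range for this vocab
        }
        return common_detokenize(ctx, tokens, /*special=*/false);
    }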

common/common.h

Lines changed: 4 additions & 0 deletions

@@ -593,6 +593,10 @@ size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
 // Vocab utils
 //
 
+bool can_be_detokenized(
+        const struct llama_context * ctx,
+        const std::vector<llama_token> & tokens);
+
 // tokenizes a string into a vector of tokens
 // should work similar to Python's `tokenizer.encode`
 std::vector<llama_token> common_tokenize(

examples/server/server.cpp

Lines changed: 5 additions & 0 deletions

@@ -2054,6 +2054,11 @@ struct server_context {
             slot.lora = task.params.lora;
         }
 
+        bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
+        if (!can_detokenize) {
+            send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
+            return false;
+        }
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
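Why the guard sits at slot launch: the server accepts prompts as raw token-id arrays as well as text (an assumption about the request format; the diff shows only the guard itself), so a client controls the ids directly and nothing upstream guarantees they are in range. A sketch of the input class the check now rejects:

    // Illustrative prompt, not from the commit: the last id is far beyond
    // any realistic vocabulary size.
    std::vector<llama_token> prompt_tokens = { 1, 15043, 999999999 };

    // pre-fix : such ids reached detokenization and indexed id_to_token out
    //           of bounds -- undefined behavior, in practice a crash (the DoS)
    // post-fix: can_be_detokenized(ctx, prompt_tokens) returns false and the
    //           request fails with ERROR_TYPE_INVALID_REQUEST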

include/llama.h

Lines changed: 4 additions & 0 deletions

@@ -1021,6 +1021,10 @@ extern "C" {
                            int32_t   lstrip,
                               bool   special);
 
+    /// @details Determine whether the provided tokens are in the vocabulary.
+    /// @return Returns true if every token provided is in the vocabulary.
+    LLAMA_API bool can_detokenize(const struct llama_vocab * vocab, const llama_token * tokens, int32_t n_tokens);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
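At the C API level the new entry point is meant to be called before llama_detokenize(). A hedged usage sketch; the retry-on-negative-return convention is an assumption carried over from how the conversion functions in this header report a too-small buffer:

    #include "llama.h"

    #include <algorithm>
    #include <string>
    #include <vector>

    // Assumes `vocab` was obtained via llama_model_get_vocab(model).
    std::string detokenize_or_reject(const llama_vocab * vocab,
                                     const std::vector<llama_token> & tokens) {
        const int32_t n_tokens = (int32_t) tokens.size();
        if (!can_detokenize(vocab, tokens.data(), n_tokens)) {
            return {}; // reject: some id is not in the vocabulary
        }

        std::string text(256, '\0');
        int32_t n = llama_detokenize(vocab, tokens.data(), n_tokens,
                                     text.data(), (int32_t) text.size(),
                                     /*remove_special=*/false,
                                     /*unparse_special=*/false);
        if (n < 0) {
            // assumed convention: a negative return carries the required size
            text.resize(-n);
            n = llama_detokenize(vocab, tokens.data(), n_tokens,
                                 text.data(), (int32_t) text.size(),
                                 /*remove_special=*/false,
                                 /*unparse_special=*/false);
        }
        text.resize(std::max<int32_t>(n, 0));
        return text;
    }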

src/llama-vocab.cpp

Lines changed: 18 additions & 0 deletions

@@ -1279,6 +1279,7 @@ struct llama_vocab::impl {
 
     std::string type_name() const;
 
+    bool is_valid   (llama_token id) const;
     bool is_normal  (llama_token id) const;
     bool is_unknown (llama_token id) const;
     bool is_control (llama_token id) const;

@@ -2068,6 +2069,11 @@ std::string llama_vocab::impl::type_name() const{
     }
 }
 
+bool llama_vocab::impl::is_valid(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return 0 <= id && id < (int32_t) id_to_token.size();
+}
+
 bool llama_vocab::impl::is_normal(llama_token id) const {
     GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
     return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;

@@ -2760,6 +2766,10 @@ std::string llama_vocab::type_name() const{
     return pimpl->type_name();
 }
 
+bool llama_vocab::is_valid(llama_token id) const {
+    return pimpl->is_valid(id);
+}
+
 bool llama_vocab::is_normal(llama_token id) const {
     return pimpl->is_normal(id);
 }

@@ -3262,3 +3272,11 @@ int32_t llama_detokenize(
     return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
 
+bool can_detokenize(const struct llama_vocab * vocab, const llama_token * tokens, int32_t n_tokens) {
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        if (!vocab->is_valid(tokens[i])) {
+            return false;
+        }
+    }
+    return true;
+}
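is_valid() is a pure range check: an id is valid exactly when it can index id_to_token, i.e. 0 <= id < vocab size, and can_detokenize() short-circuits on the first out-of-range id. A standalone illustration of the predicate (the 32000-entry vocab is an assumed figure, not from the commit):

    #include <cassert>
    #include <cstdint>

    // Same predicate as llama_vocab::impl::is_valid(), lifted out for testing.
    static bool is_valid_id(int32_t id, int32_t n_vocab) {
        return 0 <= id && id < n_vocab;
    }

    int main() {
        const int32_t n_vocab = 32000;               // assumed vocab size
        assert( is_valid_id(0,         n_vocab));    // first id
        assert( is_valid_id(n_vocab-1, n_vocab));    // last id
        assert(!is_valid_id(-1,        n_vocab));    // negative sentinel
        assert(!is_valid_id(n_vocab,   n_vocab));    // one past the end
        assert(!is_valid_id(999999999, n_vocab));    // the kind of id the fix rejects
        return 0;
    }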

src/llama-vocab.h

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ struct llama_vocab {
 
     std::string type_name() const;
 
+    bool is_valid   (llama_token id) const;
     bool is_normal  (llama_token id) const;
     bool is_unknown (llama_token id) const;
     bool is_control (llama_token id) const;
