grammars: cache decoded token codepoints for faster sampling #6811
Changes from 5 commits
@@ -12727,6 +12727,10 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
         }
     }

+    if (next_candidates.empty()) {
+        return rejects;
+    }
+
     const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;

     // update top of stack to next element, if any
@@ -12804,26 +12808,32 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);

-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, {}, {} };
 }

 void llama_grammar_free(struct llama_grammar * grammar) {
     delete grammar;
 }

 struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints };

+    std::unordered_map<const llama_grammar_element *, const llama_grammar_element *> element_map;
+    element_map.reserve(std::accumulate(
+        grammar->rules.begin(), grammar->rules.end(), 0,
+        [](size_t acc, const std::vector<llama_grammar_element> & rule) {
+            return acc + rule.size();
+        }));
+    for (size_t ir = 0; ir < grammar->rules.size(); ir++) {
+        for (size_t ie = 0; ie < grammar->rules[ir].size(); ie++) {
+            element_map[&grammar->rules[ir][ie]] = &result->rules[ir][ie];
+        }
+    }
+
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
-                        result->stacks[is][ie] = &result->rules[ir0][ir1];
-                    }
-                }
-            }
+            result->stacks[is][ie] = element_map.at(grammar->stacks[is][ie]);
         }
     }
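The `llama_grammar_copy` change above replaces a per-stack-entry scan over every rule element with a single precomputed old-pointer to new-pointer map. Below is a minimal, self-contained sketch of that remapping pattern; `Elem`, `Container`, and `copy_container` are hypothetical stand-ins, not llama.cpp's real types.

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

struct Elem { int value; };

struct Container {
    std::vector<std::vector<Elem>>         rules;  // owns the elements
    std::vector<std::vector<const Elem *>> stacks; // raw pointers into `rules`
};

Container copy_container(const Container & src) {
    Container dst;
    dst.rules  = src.rules;  // deep copy: element addresses change
    dst.stacks = src.stacks; // still points into src.rules for the moment

    // one pass over the rules records old address -> new address ...
    std::unordered_map<const Elem *, const Elem *> element_map;
    for (std::size_t ir = 0; ir < src.rules.size(); ir++) {
        for (std::size_t ie = 0; ie < src.rules[ir].size(); ie++) {
            element_map[&src.rules[ir][ie]] = &dst.rules[ir][ie];
        }
    }
    // ... so each stack entry is fixed up with a single hash lookup,
    // instead of rescanning every rule element for every pointer
    for (auto & stack : dst.stacks) {
        for (auto & ptr : stack) {
            ptr = element_map.at(ptr);
        }
    }
    return dst;
}
```

The `reserve(std::accumulate(...))` call in the PR merely pre-sizes the map to the total element count; the remapping is correct without it.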
@@ -13293,7 +13303,7 @@ void llama_sample_repetition_penalties(
     }
 }

-void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
+void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
@@ -13305,21 +13315,36 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }

+    if (grammar->token_codepoints.empty()) {
+        auto n_vocab = llama_n_vocab(llama_get_model(ctx));
+        grammar->token_codepoints.resize(n_vocab);
+        grammar->token_pieces.resize(n_vocab);
+        for (llama_token id = 0; id < n_vocab; ++id) {
+            const std::string piece = llama_token_to_piece(ctx, id, false);
+            grammar->token_pieces[id] = piece;
+            grammar->token_codepoints[id] = decode_utf8(piece, {0, 0});
+        }
+    }
+
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    if (grammar->partial_utf8.n_remain > 0) {
+        candidates_decoded.reserve(candidates->size);
+    }
     std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);

     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
-
+        const auto & piece = grammar->token_pieces[id];
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
-        } else {
+        } else if (grammar->partial_utf8.n_remain == 0) {
+            const auto & decoded = grammar->token_codepoints.at(id);
+            candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
+        } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
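The hunk above fills the per-grammar `token_pieces` and `token_codepoints` caches the first time `llama_sample_grammar` runs, so later sampling calls index into the cache instead of detokenizing and UTF-8-decoding every candidate; that lazy fill is also why the `grammar` parameter lost its `const` qualifier in the previous hunk. Below is a minimal sketch of the same memoization pattern, assuming hypothetical `token_to_piece` and `decode_codepoints` helpers (ASCII-only here), not llama.cpp's real API.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real detokenizer and UTF-8 decoder.
static std::string token_to_piece(int id) {
    return "tok" + std::to_string(id);
}
static std::vector<uint32_t> decode_codepoints(const std::string & s) {
    // ASCII-only for brevity; the real decoder handles multi-byte
    // sequences and carries partial-sequence state between tokens.
    std::vector<uint32_t> cps(s.begin(), s.end());
    cps.push_back(0); // terminating 0, as in the cached decodes above
    return cps;
}

struct token_cache {
    std::vector<std::string>           pieces;
    std::vector<std::vector<uint32_t>> codepoints;

    // Fill once, on first use; subsequent sampling calls reuse the cache.
    void ensure(int n_vocab) {
        if (!codepoints.empty()) {
            return;
        }
        pieces.resize(n_vocab);
        codepoints.resize(n_vocab);
        for (int id = 0; id < n_vocab; ++id) {
            pieces[id]     = token_to_piece(id);
            codepoints[id] = decode_codepoints(pieces[id]);
        }
    }
};
```

The upfront cost is one pass over the vocabulary on the first sampling call; afterwards each candidate lookup is a plain vector index.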
@@ -13513,10 +13538,12 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }

-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const auto & piece = grammar->token_pieces.at(token);

     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto decoded = grammar->partial_utf8.n_remain == 0
+        ? grammar->token_codepoints[token]
+        : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
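Both here and in the sampling loop, the cached decodes were produced with an empty partial-UTF-8 state (`{0, 0}`), so they are only reused when no bytes are pending from the previous token; otherwise the piece is decoded fresh against the carried-over state. A small sketch of that validity check, with a hypothetical `partial_state` type and a caller-supplied decoder (not llama.cpp's real signatures):

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical partial-UTF-8 state: bytes still expected from the last piece.
struct partial_state {
    uint32_t value    = 0;
    int      n_remain = 0;
};

using codepoints = std::vector<uint32_t>;

// The cached decode assumed an empty state, so it is only valid when no
// multi-byte sequence is pending; a pending sequence changes how the next
// piece decodes, so fall back to decoding it against the carried-over state.
codepoints codepoints_for(const std::string & piece,
                          const codepoints & cached,
                          const partial_state & st,
                          const std::function<codepoints(const std::string &, const partial_state &)> & decode) {
    if (st.n_remain == 0) {
        return cached;
    }
    return decode(piece, st);
}
```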