Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
size_t pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
Expand Down
16 changes: 10 additions & 6 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr) {
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {

std::vector<std::vector<const llama_grammar_element *>> new_stacks;
new_stacks.clear();

for (const auto & stack : stacks) {
if (stack.empty()) {
Expand All @@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
llama_grammar_advance_stack(rules, new_stack, new_stacks);
}
}

return new_stacks;
}

static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
Expand All @@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const std::vector<llama_grammar_candidate> & candidates) {

std::vector<llama_grammar_candidate> rejects;
rejects.reserve(candidates.size());

if (stack.empty()) {
for (const auto & tok : candidates) {
Expand All @@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
const llama_grammar_element * stack_pos = stack.back();

std::vector<llama_grammar_candidate> next_candidates;
next_candidates.reserve(candidates.size());

for (const auto & tok : candidates) {
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
Expand Down Expand Up @@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first;
std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
tmp_new_stacks.swap(grammar->stacks);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this better than saying grammar->stacks = tmp_new_stacks;? Because new_stacks is .clear()'d on 11921, it seems like we don't need to save its value here, and we could save a small step (?).

Mainly though, the recursive nature of the swap here was making my eyes cross when trying to follow exactly what this change was doing and how the contents of grammar->stacks and tmp_new_stacks were ping-ponging back and forth in this loop, so getting rid of the .swap() might make it a bit easier to read as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I tried making this change (into a local grammar-speedup4 branch), and it didn't significantly improve things, but it wasn't slower, and I think the code is a bit more readable:

Benchmark 1: ./main \
        -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
        --grammar-file json_numbers.grammar \
        -p "List of 20 integers starting from 0" \
        --seed 12344 (branch = grammar-speedup4)
  Time (mean ± σ):     12.586 s ±  0.698 s    [User: 8.488 s, System: 1.799 s]
  Range (min … max):   12.012 s … 13.726 s    5 runs

Benchmark 2: ./main \
        -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
        --grammar-file json_numbers.grammar \
        -p "List of 20 integers starting from 0" \
        --seed 12344 (branch = grammar-speedup3)
  Time (mean ± σ):     12.904 s ±  0.854 s    [User: 8.583 s, System: 1.954 s]
  Range (min … max):   11.846 s … 13.963 s    5 runs

Summary
  ./main \
        -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
        --grammar-file json_numbers.grammar \
        -p "List of 20 integers starting from 0" \
        --seed 12344 (branch = grammar-speedup4) ran
    1.03 ± 0.09 times faster than ./main \
        -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
        --grammar-file json_numbers.grammar \
        -p "List of 20 integers starting from 0" \
        --seed 12344 (branch = grammar-speedup3)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh, turns out my eyes-crossing swap wasn't even making things faster, removed it / looks simpler thanks.

}
grammar->partial_utf8 = decoded.second;
GGML_ASSERT(!grammar->stacks.empty());
Expand Down
5 changes: 3 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
struct llama_context * ctx
);

std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr);
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);

std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
Expand Down
6 changes: 3 additions & 3 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ number ::= [0-9]+)""";

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
}

Expand Down Expand Up @@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);

// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
Expand Down Expand Up @@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
Expand Down