Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st

const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);

size_t pos = 0;
for (const auto & cpt : cpts) {
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);

if (stacks_cur.empty()) {
error_pos = pos;
Expand Down
116 changes: 109 additions & 7 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,101 @@ static bool llama_grammar_match_partial_char(
return !is_positive_char;
}

// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
// additionally memoizes the result, mapping the input stack to its possible
// stacks as < llama_grammar_stack, llama_grammar_stacks >

static void llama_grammar_advance_stack_memo(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
llama_grammar_stacks & new_stacks,
llama_grammar_stacks_cache & stacks_cache);

// Worker for llama_grammar_advance_stack_memo: expands `stack` into every
// stack that ends at a terminal (char-range) element, appending each unique
// result to `new_stacks`. Rule references on top of the stack are expanded
// recursively (one branch per alternate of the referenced rule) via the
// memoizing wrapper, which threads `stacks_cache` through the recursion.
static void llama_grammar_advance_stack_memo_impl(
        const llama_grammar_rules & rules,
        const llama_grammar_stack & stack,
        llama_grammar_stacks & new_stacks,
        llama_grammar_stacks_cache & stacks_cache) {
    // an empty stack is already fully advanced; record it once
    if (stack.empty()) {
        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
            new_stacks.emplace_back(stack);
        }
        return;
    }

    // inspect the element on top of the stack
    const llama_grammar_element * pos = stack.back();

    switch (pos->type) {
        case LLAMA_GRETYPE_RULE_REF: {
            const size_t rule_id = static_cast<size_t>(pos->value);
            // walk every alternate definition of the referenced rule
            const llama_grammar_element * subpos = rules[rule_id].data();
            do {
                // init new stack without the top (pos)
                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
                    // if this rule ref is followed by another element, add that to stack
                    new_stack.push_back(pos + 1);
                }
                if (!llama_grammar_is_end_of_sequence(subpos)) {
                    // if alternate is nonempty, add to stack
                    new_stack.push_back(subpos);
                }
                // recurse through the memoizing wrapper so repeated sub-stacks
                // are resolved from the cache instead of re-expanded
                llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
                while (!llama_grammar_is_end_of_sequence(subpos)) {
                    // scan to end of alternate def
                    subpos++;
                }
                if (subpos->type == LLAMA_GRETYPE_ALT) {
                    // there's another alternate def of this rule to process
                    subpos++;
                } else {
                    break;
                }
            } while (true);
            break;
        }
        case LLAMA_GRETYPE_CHAR:
        case LLAMA_GRETYPE_CHAR_NOT:
        case LLAMA_GRETYPE_CHAR_ANY:
            // terminal element on top: the stack is fully advanced
            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                // only add the stack if it's not a duplicate of one we already have
                new_stacks.emplace_back(stack);
            }
            break;
        default:
            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
            // those
            GGML_ABORT("fatal error");
    }
}

static void llama_grammar_advance_stack_memo(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
llama_grammar_stacks & new_stacks,
llama_grammar_stacks_cache & stacks_cache) {

llama_grammar_stacks advanced_stacks;
// Look if stack is already in memory
auto it = stacks_cache.find(stack);
if (it != stacks_cache.end()) {
advanced_stacks = it->second;
} else {
// Advance stacks with memorization
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Advance stacks with memorization
// Advance stacks with memorization

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Advance stacks with memorization
// Advance stacks with memoization

llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
stacks_cache.insert(make_pair(stack, advanced_stacks));
}
// Add the advanced stacks to new_stacks avoiding duplicates
for (const auto & new_stack : advanced_stacks) {
if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) {
new_stacks.emplace_back(new_stack);
}
}

}

// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
static void llama_grammar_advance_stack(
Expand Down Expand Up @@ -822,11 +917,16 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks;
}

// accessor for the grammar's stack-advance memoization cache; callers pass it
// back into llama_grammar_accept (used by the gbnf-validator example and the
// grammar integration tests)
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) {
    return grammar->stacks_cache;
}

void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & stacks_new) {
llama_grammar_stacks & stacks_new,
llama_grammar_stacks_cache & stacks_cache) {
stacks_new.clear();
stacks_new.reserve(stacks.size());

Expand All @@ -844,7 +944,7 @@ void llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
llama_grammar_advance_stack(rules, new_stack, stacks_new);
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache);
}
}
}
Expand Down Expand Up @@ -938,14 +1038,15 @@ struct llama_grammar * llama_grammar_init_impl(

// loop over alternates of start rule to build initial stacks
llama_grammar_stacks stacks;
llama_grammar_stacks_cache stacks_cache;
pos = vec_rules[start_rule_index].data();
do {
llama_grammar_stack stack;
if (!llama_grammar_is_end_of_sequence(pos)) {
// if alternate is nonempty, add to stack
stack.push_back(pos);
}
llama_grammar_advance_stack(vec_rules, stack, stacks);
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
while (!llama_grammar_is_end_of_sequence(pos)) {
// scan to end of alternate def
pos++;
Expand All @@ -961,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
}

struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
Expand Down Expand Up @@ -1016,14 +1117,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,

// loop over alternates of start rule to build initial stacks
llama_grammar_stacks stacks;
llama_grammar_stacks_cache stacks_cache;
pos = vec_rules[start_rule_index].data();
do {
llama_grammar_stack stack;
if (!llama_grammar_is_end_of_sequence(pos)) {
// if alternate is nonempty, add to stack
stack.push_back(pos);
}
llama_grammar_advance_stack(vec_rules, stack, stacks);
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
while (!llama_grammar_is_end_of_sequence(pos)) {
// scan to end of alternate def
pos++;
Expand All @@ -1039,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
}

void llama_grammar_free_impl(struct llama_grammar * grammar) {
Expand Down Expand Up @@ -1129,7 +1231,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
llama_grammar_stacks stacks_new;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache);
grammar.stacks = std::move(stacks_new);
}

Expand Down
19 changes: 18 additions & 1 deletion src/llama-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "llama-impl.h"

#include <map>
#include <unordered_map>
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use unordered_map vs map? Is there any benefit or necessity of a sorted key? @ggerganov

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't think of any need to have a sorted key for this -- feels like unordered_map would be my default way to go on this one. If you're curious, a good set of profiling runs to test both options wouldn't be a bad exercise.


struct llama_vocab;

Expand Down Expand Up @@ -58,8 +59,21 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;

struct VectorPointerHash {
size_t operator()(const llama_grammar_stack & v) const {
size_t seed = v.size();
for (const auto* ptr : v) {
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};

using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;

const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar);

// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and
Expand All @@ -69,7 +83,8 @@ void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
uint32_t chr,
llama_grammar_stacks & stacks_new);
llama_grammar_stacks & stacks_new,
llama_grammar_stacks_cache & stacks_cache);

std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
Expand Down Expand Up @@ -115,6 +130,8 @@ struct llama_grammar {

// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
// cache N possible stacks from a stack
llama_grammar_stacks_cache stacks_cache;
};

//
Expand Down
3 changes: 2 additions & 1 deletion tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {

const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);

for (const auto & cpt : cpts) {
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);

if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point
Expand Down
Loading