Skip to content

Commit 5b6d504

Browse files
author
ochafik
committed
grammar: trigger words + refactor of antiprompts
1 parent 70392f1 commit 5b6d504

File tree

12 files changed

+436
-108
lines changed

12 files changed

+436
-108
lines changed

Makefile

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ BUILD_TARGETS = \
4444

4545
# Binaries only useful for tests
4646
TEST_TARGETS = \
47+
tests/test-antiprompts \
4748
tests/test-arg-parser \
4849
tests/test-autorelease \
4950
tests/test-backend-ops \
@@ -1567,6 +1568,11 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
15671568
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
15681569
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
15691570

1571+
tests/test-antiprompts: tests/test-antiprompts.cpp \
1572+
$(OBJ_ALL)
1573+
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
1574+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
1575+
15701576
tests/test-grad0: tests/test-grad0.cpp \
15711577
$(OBJ_GGML)
15721578
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/common.h

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
#include "llama.h"
66

7+
#include <queue>
78
#include <string>
89
#include <vector>
910
#include <sstream>
11+
#include <unordered_map>
1012

1113
#ifdef _WIN32
1214
#define DIRECTORY_SEPARATOR '\\'
@@ -134,6 +136,7 @@ struct gpt_sampler_params {
134136
};
135137

136138
std::string grammar; // optional BNF-like grammar to constrain sampling
139+
std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar
137140

138141
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
139142

@@ -533,6 +536,201 @@ struct llama_control_vector_load_info {
533536
// On error, returns {-1, empty}
534537
llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
535538

539+
//
540+
// Antiprompt utils
541+
//
542+
543+
class llama_antiprompts {
544+
public:
545+
546+
struct llama_antiprompt {
547+
std::string value;
548+
bool is_grammar_trigger;
549+
};
550+
551+
std::vector<std::string> stop_words;
552+
std::vector<std::string> grammar_trigger_words;
553+
554+
private:
555+
// The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
556+
// See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm
557+
struct TrieNode {
558+
std::unordered_map<char, TrieNode> children;
559+
TrieNode* fail = nullptr;
560+
int output = -1;
561+
size_t depth = 0;
562+
563+
void clear() {
564+
children.clear();
565+
fail = nullptr;
566+
output = -1;
567+
depth = 0;
568+
}
569+
};
570+
571+
TrieNode root;
572+
std::vector<llama_antiprompt> antiprompts;
573+
std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any.
574+
575+
void build_trie() {
576+
// root = std::unique_ptr<TrieNode>(new TrieNode());
577+
for (size_t i = 0; i < antiprompts.size(); ++i) {
578+
TrieNode* node = &root;
579+
const auto & pattern = antiprompts[i].value;
580+
for (size_t j = 0; j < pattern.length(); ++j) {
581+
char c = pattern[j];
582+
auto & child = node->children[c];
583+
if (child.depth == 0) {
584+
child.depth = j + 1;
585+
}
586+
node = &child;
587+
}
588+
node->output = i;
589+
}
590+
}
591+
592+
void build_failure_and_dict_links() {
593+
std::queue<TrieNode*> q;
594+
for (auto& child : root.children) {
595+
child.second.fail = &root;
596+
q.push(&child.second);
597+
}
598+
599+
while (!q.empty()) {
600+
auto node = q.front();
601+
q.pop();
602+
603+
for (auto & pair : node->children) {
604+
auto & c = pair.first;
605+
auto & child = pair.second;
606+
auto f = node->fail;
607+
608+
while (f != &root && f->children.find(c) == f->children.end()) {
609+
f = f->fail;
610+
}
611+
612+
child.fail = (f == &root && f->children.find(c) == f->children.end())
613+
? &root : &f->children[c];
614+
615+
if (child.fail->output != -1) {
616+
child.output = child.fail->output;
617+
}
618+
619+
q.push(&child);
620+
}
621+
}
622+
}
623+
624+
public:
625+
626+
bool empty() const {
627+
return antiprompts.empty() && stop_tokens.empty();
628+
}
629+
void clear() {
630+
root.clear();
631+
antiprompts.clear();
632+
stop_tokens.clear();
633+
}
634+
635+
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
636+
build(
637+
[&](const std::string & text) {
638+
return llama_tokenize(ctx, text, /* special= */ true);
639+
},
640+
stop_words,
641+
grammar_trigger_words
642+
);
643+
}
644+
645+
void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
646+
clear();
647+
this->stop_words = stop_words;
648+
this->grammar_trigger_words = grammar_trigger_words;
649+
650+
for (const std::string & stop_word : stop_words) {
651+
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
652+
}
653+
for (const std::string & trigger : grammar_trigger_words) {
654+
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
655+
}
656+
657+
for (size_t i = 0, n = antiprompts.size(); i < n; i++) {
658+
const auto & antiprompt = antiprompts[i];
659+
std::vector<llama_token> tokens = tokenizer(antiprompt.value);
660+
if (tokens.size() == 1) {
661+
stop_tokens[tokens[0]] = i;
662+
}
663+
}
664+
665+
build_trie();
666+
build_failure_and_dict_links();
667+
}
668+
669+
struct MatchResult {
670+
size_t pos;
671+
std::string pattern;
672+
bool is_partial;
673+
size_t matchLength;
674+
bool is_grammar_trigger;
675+
676+
bool operator==(const MatchResult & other) const {
677+
return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger;
678+
}
679+
operator std::string() const {
680+
return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}";
681+
}
682+
};
683+
684+
MatchResult findSingleTokenMatch(llama_token token) const {
685+
auto it = stop_tokens.find(token);
686+
if (it != stop_tokens.end()) {
687+
const auto & antiprompt = antiprompts[it->second];
688+
return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger};
689+
}
690+
return {std::string::npos, "", false, 0, false};
691+
}
692+
693+
MatchResult findFirstMatch(const std::string& text, size_t offset = 0) {
694+
TrieNode* current = &root;
695+
MatchResult partialMatch{std::string::npos, "", true, 0, false};
696+
697+
for (size_t i = offset; i < text.length(); ++i) {
698+
char c = text[i];
699+
while (current != &root && current->children.find(c) == current->children.end()) {
700+
current = current->fail;
701+
}
702+
auto it = current->children.find(c);
703+
if (it != current->children.end()) {
704+
current = &it->second;
705+
}
706+
if (current->output != -1) {
707+
const auto & antiprompt = antiprompts[current->output];
708+
return {
709+
i - antiprompt.value.length() + 1,
710+
antiprompt.value,
711+
false,
712+
antiprompt.value.length(),
713+
antiprompt.is_grammar_trigger,
714+
};
715+
}
716+
// Update partial match if we're at a deeper node
717+
if (current->depth > partialMatch.matchLength) {
718+
partialMatch.pos = i - current->depth + 1;
719+
partialMatch.pattern = ""; // We don't know which pattern it partially matches
720+
partialMatch.matchLength = current->depth;
721+
partialMatch.is_grammar_trigger = false;
722+
}
723+
}
724+
725+
// If we've found a partial match and haven't returned a full match, return the partial match
726+
if (partialMatch.pos != std::string::npos) {
727+
return partialMatch;
728+
}
729+
730+
return {std::string::npos, "", false, 0, false};
731+
}
732+
};
733+
536734
//
537735
// Split utils
538736
//

common/sampling.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,23 @@ std::string gpt_sampler_params::print() const {
139139
return std::string(result);
140140
}
141141

142+
// Lazily instantiate the grammar sampler once a trigger word has been seen.
// Returns false when a grammar sampler is already active (nothing to do);
// otherwise builds the grammar from the sampler's params, feeds it the
// trigger text that fired, and returns true.
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
    const bool already_active = gsmpl->grmr != nullptr;
    if (already_active) {
        return false;
    }
    auto * grammar = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
    gsmpl->grmr = grammar;
    // Feed the trigger text to the freshly-created grammar sampler.
    llama_sampler_accept_str(grammar, trigger.c_str());
    return true;
}
150+
142151
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
143152
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
144153

145154
lparams.no_perf = params.no_perf;
146155

147156
auto * result = new gpt_sampler {
148157
/* .params = */ params,
149-
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
158+
/* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
150159
/* .chain = */ llama_sampler_chain_init(lparams),
151160
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
152161
/* .cur = */ {},
@@ -226,7 +235,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
226235

227236
void gpt_sampler_free(struct gpt_sampler * gsmpl) {
228237
if (gsmpl) {
229-
llama_sampler_free(gsmpl->grmr);
238+
if (gsmpl->grmr) {
239+
llama_sampler_free(gsmpl->grmr);
240+
}
230241

231242
llama_sampler_free(gsmpl->chain);
232243

common/sampling.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n
7979
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
8080
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
8181

82+
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger);
83+
8284
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
8385
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);

0 commit comments

Comments (0)