|
4 | 4 |
|
5 | 5 | #include "llama.h" |
6 | 6 |
|
| 7 | +#include <queue> |
7 | 8 | #include <string> |
8 | 9 | #include <vector> |
9 | 10 | #include <sstream> |
| 11 | +#include <unordered_map> |
10 | 12 |
|
11 | 13 | #ifdef _WIN32 |
12 | 14 | #define DIRECTORY_SEPARATOR '\\' |
@@ -134,6 +136,7 @@ struct gpt_sampler_params { |
134 | 136 | }; |
135 | 137 |
|
136 | 138 | std::string grammar; // optional BNF-like grammar to constrain sampling |
| 139 | + std::vector<std::string> grammar_trigger_words; // optional trigger words to enable grammar |
137 | 140 |
|
138 | 141 | std::vector<llama_logit_bias> logit_bias; // logit biases to apply |
139 | 142 |
|
@@ -533,6 +536,201 @@ struct llama_control_vector_load_info { |
533 | 536 | // On error, returns {-1, empty} |
534 | 537 | llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos); |
535 | 538 |
|
| 539 | +// |
| 540 | +// Antiprompt utils |
| 541 | +// |
| 542 | + |
| 543 | +class llama_antiprompts { |
| 544 | + public: |
| 545 | + |
| 546 | + struct llama_antiprompt { |
| 547 | + std::string value; |
| 548 | + bool is_grammar_trigger; |
| 549 | + }; |
| 550 | + |
| 551 | + std::vector<std::string> stop_words; |
| 552 | + std::vector<std::string> grammar_trigger_words; |
| 553 | + |
| 554 | +private: |
| 555 | + // The Aho–Corasick algorithm allows efficient string matching with multiple patterns. |
| 556 | + // See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm |
| 557 | + struct TrieNode { |
| 558 | + std::unordered_map<char, TrieNode> children; |
| 559 | + TrieNode* fail = nullptr; |
| 560 | + int output = -1; |
| 561 | + size_t depth = 0; |
| 562 | + |
| 563 | + void clear() { |
| 564 | + children.clear(); |
| 565 | + fail = nullptr; |
| 566 | + output = -1; |
| 567 | + depth = 0; |
| 568 | + } |
| 569 | + }; |
| 570 | + |
| 571 | + TrieNode root; |
| 572 | + std::vector<llama_antiprompt> antiprompts; |
| 573 | + std::unordered_map<llama_token, size_t> stop_tokens; // Single token antiprompts (and their index in antiprompts), if any. |
| 574 | + |
| 575 | + void build_trie() { |
| 576 | + // root = std::unique_ptr<TrieNode>(new TrieNode()); |
| 577 | + for (size_t i = 0; i < antiprompts.size(); ++i) { |
| 578 | + TrieNode* node = &root; |
| 579 | + const auto & pattern = antiprompts[i].value; |
| 580 | + for (size_t j = 0; j < pattern.length(); ++j) { |
| 581 | + char c = pattern[j]; |
| 582 | + auto & child = node->children[c]; |
| 583 | + if (child.depth == 0) { |
| 584 | + child.depth = j + 1; |
| 585 | + } |
| 586 | + node = &child; |
| 587 | + } |
| 588 | + node->output = i; |
| 589 | + } |
| 590 | + } |
| 591 | + |
| 592 | + void build_failure_and_dict_links() { |
| 593 | + std::queue<TrieNode*> q; |
| 594 | + for (auto& child : root.children) { |
| 595 | + child.second.fail = &root; |
| 596 | + q.push(&child.second); |
| 597 | + } |
| 598 | + |
| 599 | + while (!q.empty()) { |
| 600 | + auto node = q.front(); |
| 601 | + q.pop(); |
| 602 | + |
| 603 | + for (auto & pair : node->children) { |
| 604 | + auto & c = pair.first; |
| 605 | + auto & child = pair.second; |
| 606 | + auto f = node->fail; |
| 607 | + |
| 608 | + while (f != &root && f->children.find(c) == f->children.end()) { |
| 609 | + f = f->fail; |
| 610 | + } |
| 611 | + |
| 612 | + child.fail = (f == &root && f->children.find(c) == f->children.end()) |
| 613 | + ? &root : &f->children[c]; |
| 614 | + |
| 615 | + if (child.fail->output != -1) { |
| 616 | + child.output = child.fail->output; |
| 617 | + } |
| 618 | + |
| 619 | + q.push(&child); |
| 620 | + } |
| 621 | + } |
| 622 | + } |
| 623 | + |
| 624 | + public: |
| 625 | + |
| 626 | + bool empty() const { |
| 627 | + return antiprompts.empty() && stop_tokens.empty(); |
| 628 | + } |
| 629 | + void clear() { |
| 630 | + root.clear(); |
| 631 | + antiprompts.clear(); |
| 632 | + stop_tokens.clear(); |
| 633 | + } |
| 634 | + |
| 635 | + void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) { |
| 636 | + build( |
| 637 | + [&](const std::string & text) { |
| 638 | + return llama_tokenize(ctx, text, /* special= */ true); |
| 639 | + }, |
| 640 | + stop_words, |
| 641 | + grammar_trigger_words |
| 642 | + ); |
| 643 | + } |
| 644 | + |
| 645 | + void build(const std::function<std::vector<llama_token>(const std::string)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) { |
| 646 | + clear(); |
| 647 | + this->stop_words = stop_words; |
| 648 | + this->grammar_trigger_words = grammar_trigger_words; |
| 649 | + |
| 650 | + for (const std::string & stop_word : stop_words) { |
| 651 | + antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false}); |
| 652 | + } |
| 653 | + for (const std::string & trigger : grammar_trigger_words) { |
| 654 | + antiprompts.push_back({trigger, /* is_grammar_trigger= */ true}); |
| 655 | + } |
| 656 | + |
| 657 | + for (size_t i = 0, n = antiprompts.size(); i < n; i++) { |
| 658 | + const auto & antiprompt = antiprompts[i]; |
| 659 | + std::vector<llama_token> tokens = tokenizer(antiprompt.value); |
| 660 | + if (tokens.size() == 1) { |
| 661 | + stop_tokens[tokens[0]] = i; |
| 662 | + } |
| 663 | + } |
| 664 | + |
| 665 | + build_trie(); |
| 666 | + build_failure_and_dict_links(); |
| 667 | + } |
| 668 | + |
| 669 | + struct MatchResult { |
| 670 | + size_t pos; |
| 671 | + std::string pattern; |
| 672 | + bool is_partial; |
| 673 | + size_t matchLength; |
| 674 | + bool is_grammar_trigger; |
| 675 | + |
| 676 | + bool operator==(const MatchResult & other) const { |
| 677 | + return pos == other.pos && pattern == other.pattern && is_partial == other.is_partial && matchLength == other.matchLength && is_grammar_trigger == other.is_grammar_trigger; |
| 678 | + } |
| 679 | + operator std::string() const { |
| 680 | + return "{pos=" + std::to_string(pos) + ", pattern=" + pattern + ", is_partial=" + std::to_string(is_partial) + ", matchLength=" + std::to_string(matchLength) + ", is_grammar_trigger=" + std::to_string(is_grammar_trigger) + "}"; |
| 681 | + } |
| 682 | + }; |
| 683 | + |
| 684 | + MatchResult findSingleTokenMatch(llama_token token) const { |
| 685 | + auto it = stop_tokens.find(token); |
| 686 | + if (it != stop_tokens.end()) { |
| 687 | + const auto & antiprompt = antiprompts[it->second]; |
| 688 | + return {0, antiprompt.value, false, antiprompt.value.length(), antiprompt.is_grammar_trigger}; |
| 689 | + } |
| 690 | + return {std::string::npos, "", false, 0, false}; |
| 691 | + } |
| 692 | + |
| 693 | + MatchResult findFirstMatch(const std::string& text, size_t offset = 0) { |
| 694 | + TrieNode* current = &root; |
| 695 | + MatchResult partialMatch{std::string::npos, "", true, 0, false}; |
| 696 | + |
| 697 | + for (size_t i = offset; i < text.length(); ++i) { |
| 698 | + char c = text[i]; |
| 699 | + while (current != &root && current->children.find(c) == current->children.end()) { |
| 700 | + current = current->fail; |
| 701 | + } |
| 702 | + auto it = current->children.find(c); |
| 703 | + if (it != current->children.end()) { |
| 704 | + current = &it->second; |
| 705 | + } |
| 706 | + if (current->output != -1) { |
| 707 | + const auto & antiprompt = antiprompts[current->output]; |
| 708 | + return { |
| 709 | + i - antiprompt.value.length() + 1, |
| 710 | + antiprompt.value, |
| 711 | + false, |
| 712 | + antiprompt.value.length(), |
| 713 | + antiprompt.is_grammar_trigger, |
| 714 | + }; |
| 715 | + } |
| 716 | + // Update partial match if we're at a deeper node |
| 717 | + if (current->depth > partialMatch.matchLength) { |
| 718 | + partialMatch.pos = i - current->depth + 1; |
| 719 | + partialMatch.pattern = ""; // We don't know which pattern it partially matches |
| 720 | + partialMatch.matchLength = current->depth; |
| 721 | + partialMatch.is_grammar_trigger = false; |
| 722 | + } |
| 723 | + } |
| 724 | + |
| 725 | + // If we've found a partial match and haven't returned a full match, return the partial match |
| 726 | + if (partialMatch.pos != std::string::npos) { |
| 727 | + return partialMatch; |
| 728 | + } |
| 729 | + |
| 730 | + return {std::string::npos, "", false, 0, false}; |
| 731 | + } |
| 732 | +}; |
| 733 | + |
536 | 734 | // |
537 | 735 | // Split utils |
538 | 736 | // |
|
0 commit comments