Skip to content

Commit 16c9c63

Browse files
author
ochafik
committed
add common_regex w/ support for partial final matches
1 parent e128a1b commit 16c9c63

File tree

5 files changed

+478
-0
lines changed

5 files changed

+478
-0
lines changed

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ add_library(${TARGET} STATIC
7171
minja/minja.hpp
7272
ngram-cache.cpp
7373
ngram-cache.h
74+
regex-partial.cpp
75+
regex-partial.h
7476
sampling.cpp
7577
sampling.h
7678
speculative.cpp

common/regex-partial.cpp

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
#include "regex-partial.h"
2+
#include "common.h"
3+
#include <functional>
4+
5+
// Builds a matcher for `pattern`.
//
// Two regexes are compiled up front: the pattern itself, and the reversed
// partial-match pattern produced by regex_to_reversed_partial_regex().
// `at_start` is stored as-is and consulted by search().
common_regex::common_regex(const std::string & pattern, bool at_start)
    : pattern(pattern)
    , rx(pattern)
    , rx_reversed_partial(regex_to_reversed_partial_regex(pattern))
    , at_start_(at_start)
{
}
10+
11+
// Looks for the pattern in `input`, starting at byte offset `pos`.
//
// - If `as_match` is true the whole of input[pos..] must match the pattern
//   (std::regex_match); otherwise the first occurrence is searched for
//   (std::regex_search). When the regex was built with at_start == true, a
//   searched occurrence only counts if it begins exactly at `pos`.
// - On success the result has type COMMON_REGEX_MATCH_TYPE_FULL and one range
//   per regex group (group 0 = whole match). Offsets are absolute offsets into
//   `input`.
// - Otherwise, the reversed partial regex is run over the reversed input to
//   detect whether input[pos..] *ends* with a non-empty prefix of the pattern.
//   In that case the result has type COMMON_REGEX_MATCH_TYPE_PARTIAL and a
//   single range [start of the partial match, input.size()).
//   When `as_match` or at_start is set, the partial match must begin at `pos`.
// - If neither applies, a default-constructed result (type NONE) is returned.
//
// Throws std::runtime_error if `pos` is past the end of `input`.
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
    if (pos > input.size()) {
        throw std::runtime_error("Position out of bounds");
    }
    const auto start = input.begin() + pos;

    std::smatch match;
    const bool found = as_match
        ? std::regex_match(start, input.end(), match, rx)
        : std::regex_search(start, input.end(), match, rx);
    if (found && (as_match || !at_start_ || match.position() == 0)) {
        common_regex_match res;
        res.type = COMMON_REGEX_MATCH_TYPE_FULL;
        res.groups.reserve(match.size());
        for (size_t i = 0; i < match.size(); ++i) {
            // match.position() is relative to `start`, so shift by `pos`.
            common_string_range group;
            group.begin = pos + match.position(i);
            group.end   = group.begin + match.length(i);
            res.groups.push_back(group);
        }
        return res;
    }

    std::match_results<std::string::const_reverse_iterator> srmatch;
    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
        const auto & partial = srmatch[1];
        // The group was matched on the reversed string, so the *end* of the
        // group (as a reverse iterator) is the *start* of the partial match in
        // forward order. base() yields an iterator into `input` itself, hence
        // its distance from input.begin() is already an absolute offset.
        const auto it = partial.second.base();
        const bool anchored = as_match || at_start_;
        // An empty group means no character of the pattern was matched at all.
        if (partial.length() > 0 && (!anchored || it == start)) {
            common_regex_match res;
            res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
            common_string_range range;
            range.begin = static_cast<size_t>(std::distance(input.begin(), it));
            range.end   = input.size();
            res.groups.push_back(range);
            return res;
        }
    }
    return {};
}
48+
49+
/*
50+
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
51+
52+
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
53+
to see if a string ends with a partial regex match, but it's not in std::regex yet.
54+
Instead, we'll transform the regex into a partial match regex operating as a full match on the reverse iterators of the input.
55+
56+
- /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
57+
- /a|b/ -> (a|b).*
58+
- /a*?/ -> error, could match ""
59+
- /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
60+
- /.*?ab/ -> ((?:b)?a).* (merge .*)
61+
- /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
62+
- /a.*b/ -> ((?:b)?.*?a).* (in fact any repetition becomes a reluctant match!)
63+
- /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
64+
- /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
65+
- /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
66+
67+
The regex will match a reversed string fully, and the end of the first (and only) capturing group will indicate the reversed start of the original partial pattern
68+
(i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
69+
*/
70+
std::string regex_to_reversed_partial_regex(const std::string &pattern) {
71+
auto it = pattern.begin();
72+
const auto end = pattern.end();
73+
74+
std::function<std::string()> process = [&]() {
75+
std::vector<std::vector<std::string>> alternatives(1);
76+
std::vector<std::string> * sequence = &alternatives.back();
77+
78+
while (it != end) {
79+
if (*it == '[') {
80+
auto start = it;
81+
++it;
82+
while (it != end) {
83+
if (*it == '\\' && (++it != end)) {
84+
++it;
85+
} else if (*it == ']') {
86+
break;
87+
} else {
88+
++it;
89+
}
90+
}
91+
if (it == end) {
92+
throw std::runtime_error("Unmatched '[' in pattern");
93+
}
94+
++it;
95+
sequence->push_back(std::string(start, it));
96+
} else if (*it == '*' || *it == '?' || *it == '+') {
97+
if (sequence->empty()) {
98+
throw std::runtime_error("Quantifier without preceding element");
99+
}
100+
sequence->back() += *it;
101+
auto is_star = *it == '*';
102+
++it;
103+
if (is_star) {
104+
if (*it == '?') {
105+
++it;
106+
// Convert initial reluctant quantifier to greedy to match as early as possible
107+
if (sequence->size() > 1) {
108+
sequence->back() += '?';
109+
}
110+
} else {
111+
// Convert greedy quantifiers to reluctant to not miss any matches
112+
sequence->back() += '?';
113+
}
114+
}
115+
} else if (*it == '{') {
116+
if (sequence->empty()) {
117+
throw std::runtime_error("Repetition without preceding element");
118+
}
119+
++it;
120+
auto start = it;
121+
while (it != end && *it != '}') {
122+
++it;
123+
}
124+
if (it == end) {
125+
throw std::runtime_error("Unmatched '{' in pattern");
126+
}
127+
auto parts = string_split(std::string(start, it), ",");
128+
++it;
129+
if (parts.size() > 2) {
130+
throw std::runtime_error("Invalid repetition range in pattern");
131+
}
132+
133+
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
134+
if (s.empty()) {
135+
return def;
136+
}
137+
return std::stoi(s);
138+
};
139+
auto min = parseOptInt(parts[0], 0);
140+
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
141+
if (min && max && *max < *min) {
142+
throw std::runtime_error("Invalid repetition range in pattern");
143+
}
144+
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
145+
auto part = sequence->back();
146+
sequence->pop_back();
147+
for (int i = 0; i < *min; i++) {
148+
sequence->push_back(part);
149+
}
150+
if (max) {
151+
for (int i = *min; i < *max; i++) {
152+
sequence->push_back(part + "?");
153+
}
154+
} else {
155+
sequence->push_back(part + "*");
156+
}
157+
} else if (*it == '(') {
158+
++it;
159+
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
160+
it += 2;
161+
}
162+
auto sub = process();
163+
if (*it != ')') {
164+
throw std::runtime_error("Unmatched '(' in pattern");
165+
}
166+
++it;
167+
auto & part = sequence->emplace_back("(?:");
168+
part += sub;
169+
part += ")";
170+
} else if (*it == ')') {
171+
break;
172+
} else if (*it == '|') {
173+
++it;
174+
alternatives.emplace_back();
175+
sequence = &alternatives.back();
176+
} else if (*it == '\\' && (++it != end)) {
177+
auto str = std::string("\\") + *it;
178+
sequence->push_back(str);
179+
++it;
180+
} else {
181+
sequence->push_back(std::string(1, *it));
182+
++it;
183+
}
184+
}
185+
186+
// /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
187+
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
188+
// We'll do the outermost capturing group and final .* in the enclosing function.
189+
std::vector<std::string> res_alts;
190+
for (const auto & parts : alternatives) {
191+
auto & res = res_alts.emplace_back();
192+
for (size_t i = 0; i < parts.size() - 1; i++) {
193+
res += "(?:";
194+
}
195+
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
196+
res += *it;
197+
if (it != parts.rend() - 1) {
198+
res += ")?";
199+
}
200+
}
201+
}
202+
return string_join(res_alts, "|");
203+
};
204+
auto res = process();
205+
if (it != end) {
206+
throw std::runtime_error("Unmatched '(' in pattern");
207+
}
208+
209+
return "(" + res + ").*";
210+
}

common/regex-partial.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#pragma once
2+
3+
#include <regex>
4+
#include <string>
5+
6+
// Outcome category of a common_regex search.
enum common_regex_match_type {
    COMMON_REGEX_MATCH_TYPE_NONE    = 0, // default: nothing matched
    COMMON_REGEX_MATCH_TYPE_PARTIAL = 1, // the input ends with the beginning of a match
    COMMON_REGEX_MATCH_TYPE_FULL    = 2, // the pattern matched completely
};
11+
12+
// Half-open range [begin, end) of offsets into a string.
//
// Members are zero-initialized so that a default-constructed range is a
// well-defined empty range instead of holding indeterminate values. The type
// remains an aggregate, so `{begin, end}` initialization keeps working.
struct common_string_range {
    size_t begin = 0;
    size_t end   = 0;

    // True when the range covers no characters.
    bool empty() const {
        return begin == end;
    }
    bool operator==(const common_string_range & other) const {
        return begin == other.begin && end == other.end;
    }
    bool operator!=(const common_string_range & other) const {
        return !(*this == other);
    }
};
22+
23+
struct common_regex_match {
24+
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
25+
std::vector<common_string_range> groups;
26+
27+
bool operator==(const common_regex_match & other) const {
28+
return type == other.type && groups == other.groups;
29+
}
30+
bool operator!=(const common_regex_match & other) const {
31+
return !(*this == other);
32+
}
33+
};
34+
35+
class common_regex {
36+
std::string pattern;
37+
std::regex rx;
38+
std::regex rx_reversed_partial;
39+
bool at_start_;
40+
41+
public:
42+
common_regex(const std::string & pattern, bool at_start = false);
43+
44+
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
45+
46+
const std::string & str() const { return pattern; }
47+
bool at_start() const { return at_start_; }
48+
};
49+
50+
// For testing only (pretty print of failures).
51+
std::string regex_to_reversed_partial_regex(const std::string &pattern);

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ endif()
133133
llama_target_and_test(test-log.cpp)
134134
llama_target_and_test(test-arg-parser.cpp)
135135
llama_target_and_test(test-chat-template.cpp)
136+
llama_target_and_test(test-regex-partial.cpp)
136137

137138
# llama_target_and_test(test-opt.cpp) # SLOW
138139
llama_target_and_test(test-gguf.cpp)

0 commit comments

Comments
 (0)