-
Notifications
You must be signed in to change notification settings - Fork 13.4k
common: add partial regex support
#12808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
59b87c5
ff35374
869e1a9
6f109fa
908e12f
868b442
2ea5f5c
b275da3
9b620e5
5c99bdc
e051be6
afce553
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -443,6 +443,26 @@ void string_replace_all(std::string & s, const std::string & search, const std:: | |
| s = std::move(builder); | ||
| } | ||
|
|
||
| bool string_ends_with(const std::string & str, const std::string & suffix) { | ||
| return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; | ||
| } | ||
|
|
||
| size_t string_find_partial_stop(const std::string &str, const std::string &stop) { | ||
| if (!str.empty() && !stop.empty()) { | ||
| const char text_last_char = str.back(); | ||
| for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { | ||
| if (stop[char_index] == text_last_char) { | ||
| const std::string current_partial = stop.substr(0, char_index + 1); | ||
| if (string_ends_with(str, current_partial)) { | ||
| return str.size() - char_index - 1; | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these sub-string copies can be avoided via a |
||
| } | ||
| } | ||
|
|
||
| return std::string::npos; | ||
| } | ||
|
|
||
| std::string regex_escape(const std::string & s) { | ||
| static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); | ||
| return std::regex_replace(s, special_chars, "\\$0"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -499,10 +499,9 @@ static bool string_starts_with(const std::string & str, | |||||
| return str.rfind(prefix, 0) == 0; | ||||||
| } | ||||||
|
|
||||||
| static bool string_ends_with(const std::string & str, | ||||||
| const std::string & suffix) { // While we wait for C++20's std::string::ends_with... | ||||||
| return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; | ||||||
| } | ||||||
| // While we wait for C++20's std::string::ends_with... | ||||||
| bool string_ends_with(const std::string & str, const std::string & suffix); | ||||||
| size_t string_find_partial_stop(const std::string &str, const std::string &stop); | ||||||
|
||||||
| size_t string_find_partial_stop(const std::string &str, const std::string &stop); | |
| size_t string_find_partial_stop(const std::string & str, const std::string & stop); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| #include "regex-partial.h" | ||
| #include "common.h" | ||
| #include <functional> | ||
| #include <optional> | ||
|
|
||
| common_regex::common_regex(const std::string & pattern) : | ||
| pattern(pattern), | ||
| rx(pattern), | ||
| rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {} | ||
|
|
||
| common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const { | ||
| std::smatch match; | ||
| if (pos > input.size()) { | ||
| throw std::runtime_error("Position out of bounds"); | ||
| } | ||
| auto start = input.begin() + pos; | ||
| auto found = as_match | ||
| ? std::regex_match(start, input.end(), match, rx) | ||
| : std::regex_search(start, input.end(), match, rx); | ||
| if (found) { | ||
| common_regex_match res; | ||
| res.type = COMMON_REGEX_MATCH_TYPE_FULL; | ||
| for (size_t i = 0; i < match.size(); ++i) { | ||
| auto begin = pos + match.position(i); | ||
| res.groups.emplace_back(begin, begin + match.length(i)); | ||
| } | ||
| return res; | ||
| } | ||
| std::match_results<std::string::const_reverse_iterator> srmatch; | ||
| if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { | ||
| auto group = srmatch[1].str(); | ||
| if (group.length() != 0) { | ||
| auto it = srmatch[1].second.base(); | ||
| // auto position = static_cast<size_t>(std::distance(input.begin(), it)); | ||
| if ((!as_match) || it == input.begin()) { | ||
| common_regex_match res; | ||
| res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL; | ||
| auto begin = std::distance(input.begin(), it); | ||
| GGML_ASSERT(begin >= 0); | ||
| auto end = input.size();//begin + group.length(); | ||
| GGML_ASSERT(static_cast<size_t>(begin) <= end); | ||
| res.groups.push_back({static_cast<size_t>(begin), end}); | ||
| return res; | ||
ochafik marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } | ||
| return {}; | ||
| } | ||
|
|
||
| /* | ||
| Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern. | ||
| Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html) | ||
| to see if a string ends with a partial regex match, but but it's not in std::regex yet. | ||
| Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. | ||
| - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* | ||
| - /a|b/ -> (a|b).* | ||
| - /a*?/ -> error, could match "" | ||
| - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) | ||
| - /.*?ab/ -> ((?:b)?a).* (merge .*) | ||
| - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) | ||
| - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* | ||
| - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* | ||
| - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* | ||
| The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern | ||
| (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) | ||
| */ | ||
| std::string regex_to_reversed_partial_regex(const std::string &pattern) { | ||
ochafik marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| auto it = pattern.begin(); | ||
| const auto end = pattern.end(); | ||
|
|
||
| std::function<std::string()> process = [&]() { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw I think we should prevent using lambda function when possible (especially long functions), the bad thing is that it can increase compilation time by a lot. A cleaner solution could be do have a |
||
| std::vector<std::vector<std::string>> alternatives(1); | ||
| std::vector<std::string> * sequence = &alternatives.back(); | ||
|
|
||
| while (it != end) { | ||
| if (*it == '[') { | ||
| auto start = it; | ||
| ++it; | ||
| while (it != end) { | ||
| if (*it == '\\' && (++it != end)) { | ||
| ++it; | ||
| } else if (*it == ']') { | ||
|
||
| break; | ||
| } else { | ||
| ++it; | ||
| } | ||
| } | ||
| if (it == end) { | ||
| throw std::runtime_error("Unmatched '[' in pattern"); | ||
| } | ||
| ++it; | ||
| sequence->push_back(std::string(start, it)); | ||
| } else if (*it == '*' || *it == '?' || *it == '+') { | ||
| if (sequence->empty()) { | ||
| throw std::runtime_error("Quantifier without preceding element"); | ||
| } | ||
| sequence->back() += *it; | ||
| auto is_star = *it == '*'; | ||
| ++it; | ||
| if (is_star) { | ||
| if (*it == '?') { | ||
| ++it; | ||
| } | ||
| } | ||
| } else if (*it == '{') { | ||
| if (sequence->empty()) { | ||
| throw std::runtime_error("Repetition without preceding element"); | ||
| } | ||
| ++it; | ||
| auto start = it; | ||
| while (it != end && *it != '}') { | ||
| ++it; | ||
| } | ||
| if (it == end) { | ||
| throw std::runtime_error("Unmatched '{' in pattern"); | ||
| } | ||
| auto parts = string_split(std::string(start, it), ","); | ||
| ++it; | ||
| if (parts.size() > 2) { | ||
| throw std::runtime_error("Invalid repetition range in pattern"); | ||
| } | ||
|
|
||
| auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> { | ||
| if (s.empty()) { | ||
| return def; | ||
| } | ||
| return std::stoi(s); | ||
| }; | ||
| auto min = parseOptInt(parts[0], 0); | ||
| auto max = parts.size() == 1 ? min : parseOptInt(parts[1]); | ||
| if (min && max && *max < *min) { | ||
| throw std::runtime_error("Invalid repetition range in pattern"); | ||
| } | ||
| // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded) | ||
| auto part = sequence->back(); | ||
| sequence->pop_back(); | ||
| for (int i = 0; i < *min; i++) { | ||
| sequence->push_back(part); | ||
| } | ||
| if (max) { | ||
| for (int i = *min; i < *max; i++) { | ||
| sequence->push_back(part + "?"); | ||
| } | ||
| } else { | ||
| sequence->push_back(part + "*"); | ||
| } | ||
| } else if (*it == '(') { | ||
| ++it; | ||
| if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') { | ||
| it += 2; | ||
| } | ||
| auto sub = process(); | ||
| if (*it != ')') { | ||
| throw std::runtime_error("Unmatched '(' in pattern"); | ||
| } | ||
| ++it; | ||
| auto & part = sequence->emplace_back("(?:"); | ||
| part += sub; | ||
| part += ")"; | ||
| } else if (*it == ')') { | ||
| break; | ||
| } else if (*it == '|') { | ||
| ++it; | ||
| alternatives.emplace_back(); | ||
| sequence = &alternatives.back(); | ||
| } else if (*it == '\\' && (++it != end)) { | ||
| auto str = std::string("\\") + *it; | ||
| sequence->push_back(str); | ||
| ++it; | ||
| } else { | ||
| sequence->push_back(std::string(1, *it)); | ||
| ++it; | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar problem with potentially dereferencing |
||
|
|
||
| // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* | ||
| // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group | ||
| // We'll do the outermost capturing group and final .* in the enclosing function. | ||
| std::vector<std::string> res_alts; | ||
| for (const auto & parts : alternatives) { | ||
| auto & res = res_alts.emplace_back(); | ||
| for (size_t i = 0; i < parts.size() - 1; i++) { | ||
| res += "(?:"; | ||
| } | ||
| for (auto it = parts.rbegin(); it != parts.rend(); ++it) { | ||
| res += *it; | ||
| if (it != parts.rend() - 1) { | ||
| res += ")?"; | ||
| } | ||
| } | ||
| } | ||
| return string_join(res_alts, "|"); | ||
| }; | ||
| auto res = process(); | ||
| if (it != end) { | ||
| throw std::runtime_error("Unmatched '(' in pattern"); | ||
| } | ||
|
|
||
| return "(" + res + ")[\\s\\S]*"; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| #pragma once | ||
|
|
||
| #include <regex> | ||
| #include <string> | ||
| #include "ggml.h" | ||
|
||
|
|
||
| enum common_regex_match_type { | ||
| COMMON_REGEX_MATCH_TYPE_NONE, | ||
| COMMON_REGEX_MATCH_TYPE_PARTIAL, | ||
| COMMON_REGEX_MATCH_TYPE_FULL, | ||
| }; | ||
|
|
||
| struct common_string_range { | ||
| size_t begin; | ||
| size_t end; | ||
| common_string_range(size_t begin, size_t end) : begin(begin), end(end) { | ||
| GGML_ASSERT(begin <= end); | ||
| } | ||
| // prevent default ctor | ||
| common_string_range() = delete; | ||
| bool empty() const { | ||
| return begin == end; | ||
| } | ||
| bool operator==(const common_string_range & other) const { | ||
| return begin == other.begin && end == other.end; | ||
| } | ||
| }; | ||
|
|
||
| struct common_regex_match { | ||
| common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE; | ||
| std::vector<common_string_range> groups; | ||
|
|
||
| bool operator==(const common_regex_match & other) const { | ||
| return type == other.type && groups == other.groups; | ||
| } | ||
| bool operator!=(const common_regex_match & other) const { | ||
| return !(*this == other); | ||
| } | ||
| }; | ||
|
|
||
| class common_regex { | ||
| std::string pattern; | ||
| std::regex rx; | ||
| std::regex rx_reversed_partial; | ||
|
|
||
| public: | ||
| explicit common_regex(const std::string & pattern); | ||
|
|
||
| common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const; | ||
|
|
||
| const std::string & str() const { return pattern; } | ||
| }; | ||
|
|
||
| // For testing only (pretty print of failures). | ||
| std::string regex_to_reversed_partial_regex(const std::string &pattern); | ||
ochafik marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.