Skip to content

Commit 5268ec8

Browse files
author
Olivier Chafik
committed
Refactor string helpers into common
1 parent d77fecc commit 5268ec8

File tree

5 files changed

+64
-64
lines changed

5 files changed

+64
-64
lines changed

common/common.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
485485
s = std::move(builder);
486486
}
487487

488+
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
489+
std::ostringstream result;
490+
for (size_t i = 0; i < values.size(); ++i) {
491+
if (i > 0) {
492+
result << separator;
493+
}
494+
result << values[i];
495+
}
496+
return result.str();
497+
}
498+
499+
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
500+
std::vector<std::string> tokens;
501+
size_t start = 0;
502+
size_t end = str.find(delimiter);
503+
504+
while (end != std::string::npos) {
505+
tokens.push_back(str.substr(start, end - start));
506+
start = end + delimiter.length();
507+
end = str.find(delimiter, start);
508+
}
509+
510+
tokens.push_back(str.substr(start));
511+
512+
return tokens;
513+
}
514+
515+
std::string string_repeat(const std::string & str, size_t n) {
516+
if (n == 0) {
517+
return "";
518+
}
519+
520+
std::string result;
521+
result.reserve(str.length() * n);
522+
523+
for (size_t i = 0; i < n; ++i) {
524+
result += str;
525+
}
526+
527+
return result;
528+
}
529+
488530
std::string string_from(bool value) {
489531
return value ? "true" : "false";
490532
}

common/common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,10 @@ std::string string_format(const char * fmt, ...);
431431
std::string string_strip(const std::string & str);
432432
std::string string_get_sortable_timestamp();
433433

434+
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
435+
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
436+
std::string string_repeat(const std::string & str, size_t n);
437+
434438
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
435439

436440
template<class T>

common/json-schema-to-grammar.cpp

Lines changed: 13 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "json-schema-to-grammar.h"
2+
#include "common.h"
3+
24
#include <algorithm>
35
#include <fstream>
46
#include <map>
@@ -11,8 +13,6 @@
1113

1214
using json = nlohmann::ordered_json;
1315

14-
static std::string repeat(const std::string & str, size_t n);
15-
1616
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
1717
auto has_max = max_items != std::numeric_limits<int>::max();
1818

@@ -125,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
125125
if (sub_len > 0) {
126126
auto from_sub = from.substr(i + 1);
127127
auto to_sub = to.substr(i + 1);
128-
auto sub_zeros = repeat("0", sub_len);
129-
auto sub_nines = repeat("9", sub_len);
128+
auto sub_zeros = string_repeat("0", sub_len);
129+
auto sub_nines = string_repeat("9", sub_len);
130130

131131
auto to_reached = false;
132132
out << "(";
@@ -185,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
185185
auto max_digits = max_s.length();
186186

187187
for (auto digits = min_digits; digits < max_digits; digits++) {
188-
uniform_range(min_s, repeat("9", digits));
189-
min_s = "1" + repeat("0", digits);
188+
uniform_range(min_s, string_repeat("9", digits));
189+
min_s = "1" + string_repeat("0", digits);
190190
out << " | ";
191191
}
192192
uniform_range(min_s, max_s);
@@ -315,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
315315
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
316316
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
317317

318-
template <typename Iterator>
319-
std::string join(Iterator begin, Iterator end, const std::string & separator) {
320-
std::ostringstream result;
321-
if (begin != end) {
322-
result << *begin;
323-
for (Iterator it = begin + 1; it != end; ++it) {
324-
result << separator << *it;
325-
}
326-
}
327-
return result.str();
328-
}
329-
330-
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
331-
std::vector<std::string> tokens;
332-
size_t start = 0;
333-
size_t end = str.find(delimiter);
334-
335-
while (end != std::string::npos) {
336-
tokens.push_back(str.substr(start, end - start));
337-
start = end + delimiter.length();
338-
end = str.find(delimiter, start);
339-
}
340-
341-
tokens.push_back(str.substr(start));
342-
343-
return tokens;
344-
}
345-
346-
static std::string repeat(const std::string & str, size_t n) {
347-
if (n == 0) {
348-
return "";
349-
}
350-
351-
std::string result;
352-
result.reserve(str.length() * n);
353-
354-
for (size_t i = 0; i < n; ++i) {
355-
result += str;
356-
}
357-
358-
return result;
359-
}
360-
361318
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
362319
std::smatch match;
363320
std::string result;
@@ -416,7 +373,7 @@ class SchemaConverter {
416373
for (size_t i = 0; i < alt_schemas.size(); i++) {
417374
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
418375
}
419-
return join(rules.begin(), rules.end(), " | ");
376+
return string_join(rules, " | ");
420377
}
421378

422379
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -479,7 +436,7 @@ class SchemaConverter {
479436
for (const auto & item : ret) {
480437
results.push_back(to_rule(item));
481438
}
482-
return std::make_pair(join(results.begin(), results.end(), " "), false);
439+
return std::make_pair(string_join(results, " "), false);
483440
};
484441

485442
while (i < length) {
@@ -537,7 +494,7 @@ class SchemaConverter {
537494
}
538495
curly_brackets += '}';
539496
i++;
540-
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
497+
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
541498
int min_times = 0;
542499
int max_times = std::numeric_limits<int>::max();
543500
try {
@@ -852,7 +809,7 @@ class SchemaConverter {
852809
return;
853810
}
854811
std::string pointer = ref.substr(ref.find('#') + 1);
855-
std::vector<std::string> tokens = split(pointer, "/");
812+
std::vector<std::string> tokens = string_split(pointer, "/");
856813
for (size_t i = 1; i < tokens.size(); ++i) {
857814
std::string sel = tokens[i];
858815
if (target.is_null() || !target.contains(sel)) {
@@ -903,7 +860,7 @@ class SchemaConverter {
903860
for (const auto & v : schema["enum"]) {
904861
enum_values.push_back(_generate_constant_rule(v));
905862
}
906-
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
863+
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
907864
} else if ((schema_type.is_null() || schema_type == "object")
908865
&& (schema.contains("properties") ||
909866
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -1017,10 +974,10 @@ class SchemaConverter {
1017974

1018975
void check_errors() {
1019976
if (!_errors.empty()) {
1020-
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
977+
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
1021978
}
1022979
if (!_warnings.empty()) {
1023-
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
980+
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
1024981
}
1025982
}
1026983

common/json-schema-to-grammar.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
#define JSON_ASSERT GGML_ASSERT
66
#include "json.hpp"
77

8-
template <typename Iterator>
9-
std::string join(Iterator begin, Iterator end, const std::string & separator);
10-
118
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
129

1310
struct llama_grammar_builder {

common/tool-call.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ common_tool_call_handler common_tool_call_handler_init(
631631
handler.grammar_triggers.push_back("{\n \"");
632632
}
633633

634-
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
634+
builder.add_rule("root", string_join(tool_rules, " | "));
635635
});
636636
handler.additional_stops.push_back("<|eom_id|>");
637637
handler.prompt = tmpl.apply(messages, actual_tools.empty() ? json() : actual_tools, /* add_generation_prompt= */ true, {
@@ -658,9 +658,9 @@ common_tool_call_handler common_tool_call_handler_init(
658658
handler.grammar_triggers.push_back("\n>>>" + name + "\n");
659659
}
660660
}
661-
auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space";
661+
auto first_rule = builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
662662
if (parallel) {
663-
auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space";
663+
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
664664
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
665665
} else {
666666
builder.add_rule("root", first_rule);
@@ -690,7 +690,7 @@ common_tool_call_handler common_tool_call_handler_init(
690690
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
691691
}
692692
}
693-
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
693+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
694694
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
695695
if (allow_content) {
696696
handler.grammar_triggers.push_back("<function=");
@@ -721,7 +721,7 @@ common_tool_call_handler common_tool_call_handler_init(
721721
}));
722722
}
723723

724-
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
724+
auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
725725
builder.add_rule("root", parallel ? "(" + tool_call + ")+" : tool_call);
726726
if (allow_content) {
727727
handler.grammar_triggers.push_back("<tool_call>");

0 commit comments

Comments
 (0)