Skip to content

Commit 6c7e9a5

Browse files
ochafikCISC
andauthored
vendor: sync minja (#15161)
* vendor: sync minja * Update minja.hpp * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <[email protected]> --------- Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 1425f58 commit 6c7e9a5

File tree

2 files changed

+44
-20
lines changed

2 files changed

+44
-20
lines changed

vendor/minja/chat-template.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,15 @@ class chat_template {
162162
}), false);
163163
caps_.supports_tools = contains(out, "some_tool");
164164

165-
auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
166-
auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
165+
const auto render_with_content = [&](const json & content) {
166+
const json assistant_msg {{"role", "assistant"}, {"content", content}};
167+
// Render two assistant messages as some templates like QwQ-32B are handling
168+
// the content differently depending on whether it's the last message or not
169+
// (to remove the <think> tag in all but the last message).
170+
return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false);
171+
};
172+
auto out_empty = render_with_content("");
173+
auto out_null = render_with_content(json());
167174
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
168175

169176
json j_null;
@@ -191,12 +198,12 @@ class chat_template {
191198
dummy_user_msg,
192199
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
193200
}), {}, false);
194-
auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
201+
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
195202
out = try_raw_render(json::array({
196203
dummy_user_msg,
197204
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
198205
}), {}, false);
199-
auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
206+
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
200207

201208
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
202209
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;

vendor/minja/minja.hpp

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,12 @@ class UnaryOpExpr : public Expression {
12911291
}
12921292
};
12931293

1294+
static bool in(const Value & value, const Value & container) {
1295+
return (((container.is_array() || container.is_object()) && container.contains(value)) ||
1296+
(value.is_string() && container.is_string() &&
1297+
container.to_str().find(value.to_str()) != std::string::npos));
1298+
}
1299+
12941300
class BinaryOpExpr : public Expression {
12951301
public:
12961302
enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
@@ -1355,13 +1361,8 @@ class BinaryOpExpr : public Expression {
13551361
case Op::Gt: return l > r;
13561362
case Op::Le: return l <= r;
13571363
case Op::Ge: return l >= r;
1358-
case Op::In: return (((r.is_array() || r.is_object()) && r.contains(l)) ||
1359-
(l.is_string() && r.is_string() &&
1360-
r.to_str().find(l.to_str()) != std::string::npos));
1361-
case Op::NotIn:
1362-
return !(((r.is_array() || r.is_object()) && r.contains(l)) ||
1363-
(l.is_string() && r.is_string() &&
1364-
r.to_str().find(l.to_str()) != std::string::npos));
1364+
case Op::In: return in(l, r);
1365+
case Op::NotIn: return !in(l, r);
13651366
default: break;
13661367
}
13671368
throw std::runtime_error("Unknown binary operator");
@@ -1500,6 +1501,13 @@ class MethodCallExpr : public Expression {
15001501
} else if (method->get_name() == "pop") {
15011502
vargs.expectArgs("pop method", {1, 1}, {0, 0});
15021503
return obj.pop(vargs.args[0]);
1504+
} else if (method->get_name() == "keys") {
1505+
vargs.expectArgs("keys method", {0, 0}, {0, 0});
1506+
auto result = Value::array();
1507+
for (const auto& key : obj.keys()) {
1508+
result.push_back(Value(key));
1509+
}
1510+
return result;
15031511
} else if (method->get_name() == "get") {
15041512
vargs.expectArgs("get method", {1, 2}, {0, 0});
15051513
auto key = vargs.args[0];
@@ -1541,6 +1549,16 @@ class MethodCallExpr : public Expression {
15411549
} else if (method->get_name() == "capitalize") {
15421550
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
15431551
return Value(capitalize(str));
1552+
} else if (method->get_name() == "upper") {
1553+
vargs.expectArgs("upper method", {0, 0}, {0, 0});
1554+
auto result = str;
1555+
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
1556+
return Value(result);
1557+
} else if (method->get_name() == "lower") {
1558+
vargs.expectArgs("lower method", {0, 0}, {0, 0});
1559+
auto result = str;
1560+
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
1561+
return Value(result);
15441562
} else if (method->get_name() == "endswith") {
15451563
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
15461564
auto suffix = vargs.args[0].get<std::string>();
@@ -2646,15 +2664,11 @@ inline std::shared_ptr<Context> Context::builtins() {
26462664
auto items = Value::array();
26472665
if (args.contains("object")) {
26482666
auto & obj = args.at("object");
2649-
if (obj.is_string()) {
2650-
auto json_obj = json::parse(obj.get<std::string>());
2651-
for (const auto & kv : json_obj.items()) {
2652-
items.push_back(Value::array({kv.key(), kv.value()}));
2653-
}
2654-
} else if (!obj.is_null()) {
2655-
for (auto & key : obj.keys()) {
2656-
items.push_back(Value::array({key, obj.at(key)}));
2657-
}
2667+
if (!obj.is_object()) {
2668+
throw std::runtime_error("Can only get item pairs from a mapping");
2669+
}
2670+
for (auto & key : obj.keys()) {
2671+
items.push_back(Value::array({key, obj.at(key)}));
26582672
}
26592673
}
26602674
return items;
@@ -2782,6 +2796,9 @@ inline std::shared_ptr<Context> Context::builtins() {
27822796
if (!items.is_array()) throw std::runtime_error("object is not iterable");
27832797
return items;
27842798
}));
2799+
globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
2800+
return in(args.at("item"), args.at("items"));
2801+
}));
27852802
globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
27862803
auto & items = args.at("items");
27872804
if (!items.is_array()) throw std::runtime_error("object is not iterable");

0 commit comments

Comments
 (0)