Skip to content

Commit 144a4ce

Browse files
authored
vendor : sync minja (#16500)
* sync minja.hpp Adds Call/EndCall support, used in MiniCPM3 and MiniCPM4-MCP. * remove spurious semicolon * sync from ochafik/minja
1 parent f549b00 commit 144a4ce

File tree

1 file changed

+96
-15
lines changed

1 file changed

+96
-15
lines changed

vendor/minja/minja.hpp

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ inline std::string normalize_newlines(const std::string & s) {
5555
}
5656

5757
/* Values that behave roughly like in Python. */
58-
class Value : public std::enable_shared_from_this<Value> {
58+
class Value {
5959
public:
6060
using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
6161
using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
@@ -158,12 +158,14 @@ class Value : public std::enable_shared_from_this<Value> {
158158
Value(const json & v) {
159159
if (v.is_object()) {
160160
auto object = std::make_shared<ObjectType>();
161+
object->reserve(v.size());
161162
for (auto it = v.begin(); it != v.end(); ++it) {
162-
(*object)[it.key()] = it.value();
163+
object->emplace_back(it.key(), Value(it.value()));
163164
}
164165
object_ = std::move(object);
165166
} else if (v.is_array()) {
166167
auto array = std::make_shared<ArrayType>();
168+
array->reserve(v.size());
167169
for (const auto& item : v) {
168170
array->push_back(Value(item));
169171
}
@@ -610,7 +612,7 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
610612
return out.str();
611613
}
612614

613-
class Context : public std::enable_shared_from_this<Context> {
615+
class Context {
614616
protected:
615617
Value values_;
616618
std::shared_ptr<Context> parent_;
@@ -706,7 +708,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
706708

707709
class TemplateToken {
708710
public:
709-
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
711+
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall };
710712

711713
static std::string typeToString(Type t) {
712714
switch (t) {
@@ -729,6 +731,8 @@ class TemplateToken {
729731
case Type::EndGeneration: return "endgeneration";
730732
case Type::Break: return "break";
731733
case Type::Continue: return "continue";
734+
case Type::Call: return "call";
735+
case Type::EndCall: return "endcall";
732736
}
733737
return "Unknown";
734738
}
@@ -846,6 +850,17 @@ struct LoopControlTemplateToken : public TemplateToken {
846850
LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
847851
};
848852

853+
struct CallTemplateToken : public TemplateToken {
854+
std::shared_ptr<Expression> expr;
855+
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
856+
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
857+
};
858+
859+
struct EndCallTemplateToken : public TemplateToken {
860+
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
861+
: TemplateToken(Type::EndCall, loc, pre, post) {}
862+
};
863+
849864
class TemplateNode {
850865
Location location_;
851866
protected:
@@ -1047,36 +1062,48 @@ class MacroNode : public TemplateNode {
10471062
}
10481063
}
10491064
}
1050-
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
1065+
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
10511066
if (!name) throw std::runtime_error("MacroNode.name is null");
10521067
if (!body) throw std::runtime_error("MacroNode.body is null");
1053-
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
1054-
auto call_context = macro_context;
1068+
1069+
// Use init-capture to avoid dangling 'this' pointer and circular references
1070+
auto callable = Value::callable([weak_context = std::weak_ptr<Context>(context),
1071+
name = name, params = params, body = body,
1072+
named_param_positions = named_param_positions]
1073+
(const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
1074+
auto context_locked = weak_context.lock();
1075+
if (!context_locked) throw std::runtime_error("Macro context no longer valid");
1076+
auto execution_context = Context::make(Value::object(), context_locked);
1077+
1078+
if (call_context->contains("caller")) {
1079+
execution_context->set("caller", call_context->get("caller"));
1080+
}
1081+
10551082
std::vector<bool> param_set(params.size(), false);
10561083
for (size_t i = 0, n = args.args.size(); i < n; i++) {
10571084
auto & arg = args.args[i];
10581085
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
10591086
param_set[i] = true;
1060-
auto & param_name = params[i].first;
1061-
call_context->set(param_name, arg);
1087+
const auto & param_name = params[i].first;
1088+
execution_context->set(param_name, arg);
10621089
}
10631090
for (auto & [arg_name, value] : args.kwargs) {
10641091
auto it = named_param_positions.find(arg_name);
10651092
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
10661093

1067-
call_context->set(arg_name, value);
1094+
execution_context->set(arg_name, value);
10681095
param_set[it->second] = true;
10691096
}
10701097
// Set default values for parameters that were not passed
10711098
for (size_t i = 0, n = params.size(); i < n; i++) {
10721099
if (!param_set[i] && params[i].second != nullptr) {
1073-
auto val = params[i].second->evaluate(context);
1074-
call_context->set(params[i].first, val);
1100+
auto val = params[i].second->evaluate(call_context);
1101+
execution_context->set(params[i].first, val);
10751102
}
10761103
}
1077-
return body->render(call_context);
1104+
return body->render(execution_context);
10781105
});
1079-
macro_context->set(name->get_name(), callable);
1106+
context->set(name->get_name(), callable);
10801107
}
10811108
};
10821109

@@ -1611,6 +1638,44 @@ class CallExpr : public Expression {
16111638
}
16121639
};
16131640

1641+
class CallNode : public TemplateNode {
1642+
std::shared_ptr<Expression> expr;
1643+
std::shared_ptr<TemplateNode> body;
1644+
1645+
public:
1646+
CallNode(const Location & loc, std::shared_ptr<Expression> && e, std::shared_ptr<TemplateNode> && b)
1647+
: TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {}
1648+
1649+
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
1650+
if (!expr) throw std::runtime_error("CallNode.expr is null");
1651+
if (!body) throw std::runtime_error("CallNode.body is null");
1652+
1653+
// Use init-capture to avoid dangling 'this' pointer and circular references
1654+
auto caller = Value::callable([weak_context = std::weak_ptr<Context>(context), body=body]
1655+
(const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
1656+
auto context_locked = weak_context.lock();
1657+
if (!context_locked) throw std::runtime_error("Caller context no longer valid");
1658+
return Value(body->render(context_locked));
1659+
});
1660+
1661+
context->set("caller", caller);
1662+
1663+
auto call_expr = dynamic_cast<CallExpr*>(expr.get());
1664+
if (!call_expr) {
1665+
throw std::runtime_error("Invalid call block syntax - expected function call");
1666+
}
1667+
1668+
Value function = call_expr->object->evaluate(context);
1669+
if (!function.is_callable()) {
1670+
throw std::runtime_error("Call target must be callable: " + function.dump());
1671+
}
1672+
ArgumentsValue args = call_expr->args.evaluate(context);
1673+
1674+
Value result = function.call(context, args);
1675+
out << result.to_str();
1676+
}
1677+
};
1678+
16141679
class FilterExpr : public Expression {
16151680
std::vector<std::shared_ptr<Expression>> parts;
16161681
public:
@@ -2320,7 +2385,7 @@ class Parser {
23202385
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
23212386
static std::regex expr_open_regex(R"(\{\{([-~])?)");
23222387
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
2323-
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
2388+
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)");
23242389
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
23252390
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
23262391
static std::regex block_close_regex(R"(\s*([-~])?%\})");
@@ -2443,6 +2508,15 @@ class Parser {
24432508
} else if (keyword == "endmacro") {
24442509
auto post_space = parseBlockClose();
24452510
tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
2511+
} else if (keyword == "call") {
2512+
auto expr = parseExpression();
2513+
if (!expr) throw std::runtime_error("Expected expression in call block");
2514+
2515+
auto post_space = parseBlockClose();
2516+
tokens.push_back(std::make_unique<CallTemplateToken>(location, pre_space, post_space, std::move(expr)));
2517+
} else if (keyword == "endcall") {
2518+
auto post_space = parseBlockClose();
2519+
tokens.push_back(std::make_unique<EndCallTemplateToken>(location, pre_space, post_space));
24462520
} else if (keyword == "filter") {
24472521
auto filter = parseExpression();
24482522
if (!filter) throw std::runtime_error("Expected expression in filter block");
@@ -2575,6 +2649,12 @@ class Parser {
25752649
throw unterminated(**start);
25762650
}
25772651
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
2652+
} else if (auto call_token = dynamic_cast<CallTemplateToken*>(token.get())) {
2653+
auto body = parseTemplate(begin, it, end);
2654+
if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) {
2655+
throw unterminated(**start);
2656+
}
2657+
children.emplace_back(std::make_shared<CallNode>(token->location, std::move(call_token->expr), std::move(body)));
25782658
} else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
25792659
auto body = parseTemplate(begin, it, end);
25802660
if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
@@ -2588,6 +2668,7 @@ class Parser {
25882668
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
25892669
|| dynamic_cast<EndSetTemplateToken*>(token.get())
25902670
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
2671+
|| dynamic_cast<EndCallTemplateToken*>(token.get())
25912672
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
25922673
|| dynamic_cast<EndIfTemplateToken*>(token.get())
25932674
|| dynamic_cast<ElseTemplateToken*>(token.get())

0 commit comments

Comments
 (0)