Skip to content

Commit 5ff6ec7

Browse files
authored
Support call/endcall blocks (google#81)
* Add call/endcall support * Add syntax tests for call blocks * Add call blocks to supported features in readme * Add openbmb/MiniCPM3-4B to test models * Remove non-existent model * Add tests for unterminated call and macro
1 parent e9a9bb2 commit 5ff6ec7

File tree

4 files changed

+136
-11
lines changed

4 files changed

+136
-11
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,11 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji
104104
- Full expression syntax
105105
- Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}`
106106
- `if` / `elif` / `else` / `endif`
107-
- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring
107+
- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring)
108108
- `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39))
109109
- `set` w/ namespaces & destructuring
110110
- `macro` / `endmacro`
111+
- `call` / `endcall` - for calling macro (w/ macro arguments and `caller()` syntax) and passing a macro to another macro (w/o passing arguments back to the call block)
111112
- `filter` / `endfilter`
112113
- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim`
113114

include/minja/minja.hpp

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };
706706

707707
class TemplateToken {
708708
public:
709-
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
709+
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall };
710710

711711
static std::string typeToString(Type t) {
712712
switch (t) {
@@ -729,6 +729,8 @@ class TemplateToken {
729729
case Type::EndGeneration: return "endgeneration";
730730
case Type::Break: return "break";
731731
case Type::Continue: return "continue";
732+
case Type::Call: return "call";
733+
case Type::EndCall: return "endcall";
732734
}
733735
return "Unknown";
734736
}
@@ -846,6 +848,17 @@ struct LoopControlTemplateToken : public TemplateToken {
846848
LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
847849
};
848850

851+
struct CallTemplateToken : public TemplateToken {
852+
std::shared_ptr<Expression> expr;
853+
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
854+
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
855+
};
856+
857+
struct EndCallTemplateToken : public TemplateToken {
858+
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
859+
: TemplateToken(Type::EndCall, loc, pre, post) {}
860+
};
861+
849862
class TemplateNode {
850863
Location location_;
851864
protected:
@@ -1050,31 +1063,36 @@ class MacroNode : public TemplateNode {
10501063
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
10511064
if (!name) throw std::runtime_error("MacroNode.name is null");
10521065
if (!body) throw std::runtime_error("MacroNode.body is null");
1053-
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
1054-
auto call_context = macro_context;
1066+
auto callable = Value::callable([this, macro_context](const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
1067+
auto execution_context = Context::make(Value::object(), macro_context);
1068+
1069+
if (call_context->contains("caller")) {
1070+
execution_context->set("caller", call_context->get("caller"));
1071+
}
1072+
10551073
std::vector<bool> param_set(params.size(), false);
10561074
for (size_t i = 0, n = args.args.size(); i < n; i++) {
10571075
auto & arg = args.args[i];
10581076
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
10591077
param_set[i] = true;
10601078
auto & param_name = params[i].first;
1061-
call_context->set(param_name, arg);
1079+
execution_context->set(param_name, arg);
10621080
}
10631081
for (auto & [arg_name, value] : args.kwargs) {
10641082
auto it = named_param_positions.find(arg_name);
10651083
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);
10661084

1067-
call_context->set(arg_name, value);
1085+
execution_context->set(arg_name, value);
10681086
param_set[it->second] = true;
10691087
}
10701088
// Set default values for parameters that were not passed
10711089
for (size_t i = 0, n = params.size(); i < n; i++) {
10721090
if (!param_set[i] && params[i].second != nullptr) {
1073-
auto val = params[i].second->evaluate(context);
1074-
call_context->set(params[i].first, val);
1091+
auto val = params[i].second->evaluate(call_context);
1092+
execution_context->set(params[i].first, val);
10751093
}
10761094
}
1077-
return body->render(call_context);
1095+
return body->render(execution_context);
10781096
});
10791097
macro_context->set(name->get_name(), callable);
10801098
}
@@ -1611,6 +1629,40 @@ class CallExpr : public Expression {
16111629
}
16121630
};
16131631

1632+
class CallNode : public TemplateNode {
1633+
std::shared_ptr<Expression> expr;
1634+
std::shared_ptr<TemplateNode> body;
1635+
1636+
public:
1637+
CallNode(const Location & loc, std::shared_ptr<Expression> && e, std::shared_ptr<TemplateNode> && b)
1638+
: TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {}
1639+
1640+
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
1641+
if (!expr) throw std::runtime_error("CallNode.expr is null");
1642+
if (!body) throw std::runtime_error("CallNode.body is null");
1643+
1644+
auto caller = Value::callable([this, context](const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
1645+
return Value(body->render(context));
1646+
});
1647+
1648+
context->set("caller", caller);
1649+
1650+
auto call_expr = dynamic_cast<CallExpr*>(expr.get());
1651+
if (!call_expr) {
1652+
throw std::runtime_error("Invalid call block syntax - expected function call");
1653+
}
1654+
1655+
Value function = call_expr->object->evaluate(context);
1656+
if (!function.is_callable()) {
1657+
throw std::runtime_error("Call target must be callable: " + function.dump());
1658+
}
1659+
ArgumentsValue args = call_expr->args.evaluate(context);
1660+
1661+
Value result = function.call(context, args);
1662+
out << result.to_str();
1663+
}
1664+
};
1665+
16141666
class FilterExpr : public Expression {
16151667
std::vector<std::shared_ptr<Expression>> parts;
16161668
public:
@@ -2320,7 +2372,7 @@ class Parser {
23202372
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
23212373
static std::regex expr_open_regex(R"(\{\{([-~])?)");
23222374
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
2323-
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
2375+
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)");
23242376
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
23252377
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
23262378
static std::regex block_close_regex(R"(\s*([-~])?%\})");
@@ -2443,6 +2495,15 @@ class Parser {
24432495
} else if (keyword == "endmacro") {
24442496
auto post_space = parseBlockClose();
24452497
tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
2498+
} else if (keyword == "call") {
2499+
auto expr = parseExpression();
2500+
if (!expr) throw std::runtime_error("Expected expression in call block");
2501+
2502+
auto post_space = parseBlockClose();
2503+
tokens.push_back(std::make_unique<CallTemplateToken>(location, pre_space, post_space, std::move(expr)));
2504+
} else if (keyword == "endcall") {
2505+
auto post_space = parseBlockClose();
2506+
tokens.push_back(std::make_unique<EndCallTemplateToken>(location, pre_space, post_space));
24462507
} else if (keyword == "filter") {
24472508
auto filter = parseExpression();
24482509
if (!filter) throw std::runtime_error("Expected expression in filter block");
@@ -2575,6 +2636,12 @@ class Parser {
25752636
throw unterminated(**start);
25762637
}
25772638
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
2639+
} else if (auto call_token = dynamic_cast<CallTemplateToken*>(token.get())) {
2640+
auto body = parseTemplate(begin, it, end);
2641+
if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) {
2642+
throw unterminated(**start);
2643+
}
2644+
children.emplace_back(std::make_shared<CallNode>(token->location, std::move(call_token->expr), std::move(body)));
25782645
} else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
25792646
auto body = parseTemplate(begin, it, end);
25802647
if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
@@ -2588,6 +2655,7 @@ class Parser {
25882655
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
25892656
|| dynamic_cast<EndSetTemplateToken*>(token.get())
25902657
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
2658+
|| dynamic_cast<EndCallTemplateToken*>(token.get())
25912659
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
25922660
|| dynamic_cast<EndIfTemplateToken*>(token.get())
25932661
|| dynamic_cast<ElseTemplateToken*>(token.get())

tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ set(MODEL_IDS
226226
OnlyCheeini/greesychat-turbo
227227
onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX
228228
open-thoughts/OpenThinker-7B
229+
openbmb/MiniCPM3-4B
229230
openchat/openchat-3.5-0106
230231
Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2
231232
OrionStarAI/Orion-14B-Chat
@@ -261,7 +262,6 @@ set(MODEL_IDS
261262
prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M
262263
prithivMLmods/QwQ-Math-IO-500M
263264
prithivMLmods/Triangulum-v2-10B
264-
qingy2024/Falcon3-2x10B-MoE-Instruct
265265
Qwen/QVQ-72B-Preview
266266
Qwen/Qwen1.5-7B-Chat
267267
Qwen/Qwen2-7B-Instruct

tests/test-syntax.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,54 @@ TEST(SyntaxTest, SimpleCases) {
429429
{%- endmacro -%}
430430
{{- foo() }} {{ foo() -}})", {}, {}));
431431

432+
EXPECT_EQ(
433+
"x,x",
434+
render(R"(
435+
{%- macro test() -%}{{ caller() }},{{ caller() }}{%- endmacro -%}
436+
{%- call test() -%}x{%- endcall -%}
437+
)", {}, {}));
438+
439+
EXPECT_EQ(
440+
"Outer[Inner(X)]",
441+
render(R"(
442+
{%- macro outer() -%}Outer[{{ caller() }}]{%- endmacro -%}
443+
{%- macro inner() -%}Inner({{ caller() }}){%- endmacro -%}
444+
{%- call outer() -%}{%- call inner() -%}X{%- endcall -%}{%- endcall -%}
445+
)", {}, {}));
446+
447+
EXPECT_EQ(
448+
"<ul><li>A</li><li>B</li></ul>",
449+
render(R"(
450+
{%- macro test(prefix, suffix) -%}{{ prefix }}{{ caller() }}{{ suffix }}{%- endmacro -%}
451+
{%- set items = ["a", "b"] -%}
452+
{%- call test("<ul>", "</ul>") -%}
453+
{%- for item in items -%}
454+
<li>{{ item | upper }}</li>
455+
{%- endfor -%}
456+
{%- endcall -%}
457+
)", {}, {}));
458+
459+
EXPECT_EQ(
460+
"\\n\\nclass A:\\n b: 1\\n c: 2\\n",
461+
render(R"(
462+
{%- macro recursive(obj) -%}
463+
{%- set ns = namespace(content = caller()) -%}
464+
{%- for key, value in obj.items() %}
465+
{%- if value is mapping %}
466+
{%- call recursive(value) -%}
467+
{{ '\\n\\nclass ' + key.title() + ':\\n' }}
468+
{%- endcall -%}
469+
{%- else -%}
470+
{%- set ns.content = ns.content + ' ' + key + ': ' + value + '\\n' -%}
471+
{%- endif -%}
472+
{%- endfor -%}
473+
{{ ns.content }}
474+
{%- endmacro -%}
475+
476+
{%- call recursive({"a": {"b": "1", "c": "2"}}) -%}
477+
{%- endcall -%}
478+
)", {}, {}));
479+
432480
if (!getenv("USE_JINJA2")) {
433481
EXPECT_EQ(
434482
"Foo",
@@ -576,6 +624,8 @@ TEST(SyntaxTest, SimpleCases) {
576624
EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected elif"));
577625
EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfor"));
578626
EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfilter"));
627+
EXPECT_THAT([]() { render("{% endmacro %}", {}, {}); }, ThrowsWithSubstr("Unexpected endmacro"));
628+
EXPECT_THAT([]() { render("{% endcall %}", {}, {}); }, ThrowsWithSubstr("Unexpected endcall"));
579629

580630
EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
581631
EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for"));
@@ -584,6 +634,12 @@ TEST(SyntaxTest, SimpleCases) {
584634
EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
585635
EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter"));
586636
EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag"));
637+
EXPECT_THAT([]() { render("{% macro test() %}", {}, {}); }, ThrowsWithSubstr("Unterminated macro"));
638+
EXPECT_THAT([]() { render("{% call test %}", {}, {}); }, ThrowsWithSubstr("Unterminated call"));
639+
640+
EXPECT_THAT([]() {
641+
render("{%- macro test() -%}content{%- endmacro -%}{%- call test -%}caller_content{%- endcall -%}", {}, {});
642+
}, ThrowsWithSubstr("Invalid call block syntax - expected function call"));
587643
}
588644

589645
EXPECT_EQ(

0 commit comments

Comments
 (0)