Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji
- Full expression syntax
- Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}`
- `if` / `elif` / `else` / `endif`
- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring
- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring)
- `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39))
- `set` w/ namespaces & destructuring
- `macro` / `endmacro`
- `call` / `endcall` - for calling macro (w/ macro arguments and `caller()` syntax) and passing a macro to another macro (w/o passing arguments back to the call block)
- `filter` / `endfilter`
- Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim`

Expand Down
86 changes: 77 additions & 9 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };

class TemplateToken {
public:
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall };

static std::string typeToString(Type t) {
switch (t) {
Expand All @@ -729,6 +729,8 @@ class TemplateToken {
case Type::EndGeneration: return "endgeneration";
case Type::Break: return "break";
case Type::Continue: return "continue";
case Type::Call: return "call";
case Type::EndCall: return "endcall";
}
return "Unknown";
}
Expand Down Expand Up @@ -846,6 +848,17 @@ struct LoopControlTemplateToken : public TemplateToken {
LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {}
};

struct CallTemplateToken : public TemplateToken {
std::shared_ptr<Expression> expr;
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
};

struct EndCallTemplateToken : public TemplateToken {
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
: TemplateToken(Type::EndCall, loc, pre, post) {}
};

class TemplateNode {
Location location_;
protected:
Expand Down Expand Up @@ -1050,31 +1063,36 @@ class MacroNode : public TemplateNode {
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
if (!name) throw std::runtime_error("MacroNode.name is null");
if (!body) throw std::runtime_error("MacroNode.body is null");
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
auto call_context = macro_context;
auto callable = Value::callable([this, macro_context](const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
auto execution_context = Context::make(Value::object(), macro_context);

if (call_context->contains("caller")) {
execution_context->set("caller", call_context->get("caller"));
}

std::vector<bool> param_set(params.size(), false);
for (size_t i = 0, n = args.args.size(); i < n; i++) {
auto & arg = args.args[i];
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
param_set[i] = true;
auto & param_name = params[i].first;
call_context->set(param_name, arg);
execution_context->set(param_name, arg);
}
for (auto & [arg_name, value] : args.kwargs) {
auto it = named_param_positions.find(arg_name);
if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name);

call_context->set(arg_name, value);
execution_context->set(arg_name, value);
param_set[it->second] = true;
}
// Set default values for parameters that were not passed
for (size_t i = 0, n = params.size(); i < n; i++) {
if (!param_set[i] && params[i].second != nullptr) {
auto val = params[i].second->evaluate(context);
call_context->set(params[i].first, val);
auto val = params[i].second->evaluate(call_context);
execution_context->set(params[i].first, val);
}
}
return body->render(call_context);
return body->render(execution_context);
});
macro_context->set(name->get_name(), callable);
}
Expand Down Expand Up @@ -1611,6 +1629,40 @@ class CallExpr : public Expression {
}
};

class CallNode : public TemplateNode {
std::shared_ptr<Expression> expr;
std::shared_ptr<TemplateNode> body;

public:
CallNode(const Location & loc, std::shared_ptr<Expression> && e, std::shared_ptr<TemplateNode> && b)
: TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {}

void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("CallNode.expr is null");
if (!body) throw std::runtime_error("CallNode.body is null");

auto caller = Value::callable([this, context](const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
return Value(body->render(context));
});

context->set("caller", caller);

auto call_expr = dynamic_cast<CallExpr*>(expr.get());
if (!call_expr) {
throw std::runtime_error("Invalid call block syntax - expected function call");
}

Value function = call_expr->object->evaluate(context);
if (!function.is_callable()) {
throw std::runtime_error("Call target must be callable: " + function.dump());
}
ArgumentsValue args = call_expr->args.evaluate(context);

Value result = function.call(context, args);
out << result.to_str();
}
};

class FilterExpr : public Expression {
std::vector<std::shared_ptr<Expression>> parts;
public:
Expand Down Expand Up @@ -2320,7 +2372,7 @@ class Parser {
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
static std::regex block_close_regex(R"(\s*([-~])?%\})");
Expand Down Expand Up @@ -2443,6 +2495,15 @@ class Parser {
} else if (keyword == "endmacro") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndMacroTemplateToken>(location, pre_space, post_space));
} else if (keyword == "call") {
auto expr = parseExpression();
if (!expr) throw std::runtime_error("Expected expression in call block");

auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<CallTemplateToken>(location, pre_space, post_space, std::move(expr)));
} else if (keyword == "endcall") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndCallTemplateToken>(location, pre_space, post_space));
} else if (keyword == "filter") {
auto filter = parseExpression();
if (!filter) throw std::runtime_error("Expected expression in filter block");
Expand Down Expand Up @@ -2575,6 +2636,12 @@ class Parser {
throw unterminated(**start);
}
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
} else if (auto call_token = dynamic_cast<CallTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) {
throw unterminated(**start);
}
children.emplace_back(std::make_shared<CallNode>(token->location, std::move(call_token->expr), std::move(body)));
} else if (auto filter_token = dynamic_cast<FilterTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) {
Expand All @@ -2588,6 +2655,7 @@ class Parser {
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|| dynamic_cast<EndSetTemplateToken*>(token.get())
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
|| dynamic_cast<EndCallTemplateToken*>(token.get())
|| dynamic_cast<EndFilterTemplateToken*>(token.get())
|| dynamic_cast<EndIfTemplateToken*>(token.get())
|| dynamic_cast<ElseTemplateToken*>(token.get())
Expand Down
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ set(MODEL_IDS
OnlyCheeini/greesychat-turbo
onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX
open-thoughts/OpenThinker-7B
openbmb/MiniCPM3-4B
openchat/openchat-3.5-0106
Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2
OrionStarAI/Orion-14B-Chat
Expand Down Expand Up @@ -261,7 +262,6 @@ set(MODEL_IDS
prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M
prithivMLmods/QwQ-Math-IO-500M
prithivMLmods/Triangulum-v2-10B
qingy2024/Falcon3-2x10B-MoE-Instruct
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this removal accidental?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Qwen/QVQ-72B-Preview
Qwen/Qwen1.5-7B-Chat
Qwen/Qwen2-7B-Instruct
Expand Down
52 changes: 52 additions & 0 deletions tests/test-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,54 @@ TEST(SyntaxTest, SimpleCases) {
{%- endmacro -%}
{{- foo() }} {{ foo() -}})", {}, {}));

EXPECT_EQ(
"x,x",
render(R"(
{%- macro test() -%}{{ caller() }},{{ caller() }}{%- endmacro -%}
{%- call test() -%}x{%- endcall -%}
)", {}, {}));

EXPECT_EQ(
"Outer[Inner(X)]",
render(R"(
{%- macro outer() -%}Outer[{{ caller() }}]{%- endmacro -%}
{%- macro inner() -%}Inner({{ caller() }}){%- endmacro -%}
{%- call outer() -%}{%- call inner() -%}X{%- endcall -%}{%- endcall -%}
)", {}, {}));

EXPECT_EQ(
"<ul><li>A</li><li>B</li></ul>",
render(R"(
{%- macro test(prefix, suffix) -%}{{ prefix }}{{ caller() }}{{ suffix }}{%- endmacro -%}
{%- set items = ["a", "b"] -%}
{%- call test("<ul>", "</ul>") -%}
{%- for item in items -%}
<li>{{ item | upper }}</li>
{%- endfor -%}
{%- endcall -%}
)", {}, {}));

EXPECT_EQ(
"\\n\\nclass A:\\n b: 1\\n c: 2\\n",
render(R"(
{%- macro recursive(obj) -%}
{%- set ns = namespace(content = caller()) -%}
{%- for key, value in obj.items() %}
{%- if value is mapping %}
{%- call recursive(value) -%}
{{ '\\n\\nclass ' + key.title() + ':\\n' }}
{%- endcall -%}
{%- else -%}
{%- set ns.content = ns.content + ' ' + key + ': ' + value + '\\n' -%}
{%- endif -%}
{%- endfor -%}
{{ ns.content }}
{%- endmacro -%}

{%- call recursive({"a": {"b": "1", "c": "2"}}) -%}
{%- endcall -%}
)", {}, {}));

if (!getenv("USE_JINJA2")) {
EXPECT_EQ(
"Foo",
Expand Down Expand Up @@ -584,6 +632,10 @@ TEST(SyntaxTest, SimpleCases) {
EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if"));
EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter"));
EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag"));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please also add an unterminated call test:

EXPECT_THAT([]() { render("{%- call test -%}", {}, {}); }, ThrowsWithSubstr("Missing end of call tag"));

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests for call and macro added

EXPECT_THAT([]() {
render("{%- macro test() -%}content{%- endmacro -%}{%- call test -%}caller_content{%- endcall -%}", {}, {});
}, ThrowsWithSubstr("Invalid call block syntax - expected function call"));
}

EXPECT_EQ(
Expand Down
Loading