Skip to content

Commit a065fcf

Browse files
Support Qwen3 (str.startswith() and [::-1]) (google#66)
* Add str.startswith()
* Add support for step=-1 in slice
* Add Qwen3 template
* Clamp out-of-bounds slice indices
* Simplify subscript logic + handle any (non-zero) step

---------

Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
1 parent b3fca45 commit a065fcf

File tree

3 files changed

+73
-30
lines changed

3 files changed

+73
-30
lines changed

include/minja/minja.hpp

Lines changed: 60 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1201,9 +1201,9 @@ class DictExpr : public Expression {
12011201

12021202
class SliceExpr : public Expression {
12031203
public:
1204-
std::shared_ptr<Expression> start, end;
1205-
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
1206-
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
1204+
std::shared_ptr<Expression> start, end, step;
1205+
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
1206+
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
12071207
Value do_evaluate(const std::shared_ptr<Context> &) const override {
12081208
throw std::runtime_error("SliceExpr not implemented");
12091209
}
@@ -1220,18 +1220,35 @@ class SubscriptExpr : public Expression {
12201220
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
12211221
auto target_value = base->evaluate(context);
12221222
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
1223-
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
1224-
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
1223+
auto len = target_value.size();
1224+
auto wrap = [len](int64_t i) -> int64_t {
1225+
if (i < 0) {
1226+
return i + len;
1227+
}
1228+
return i;
1229+
};
1230+
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
1231+
if (!step) {
1232+
throw std::runtime_error("slice step cannot be zero");
1233+
}
1234+
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
1235+
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
12251236
if (target_value.is_string()) {
12261237
std::string s = target_value.get<std::string>();
1227-
if (start < 0) start = s.size() + start;
1228-
if (end < 0) end = s.size() + end;
1229-
return s.substr(start, end - start);
1230-
} else if (target_value.is_array()) {
1231-
if (start < 0) start = target_value.size() + start;
1232-
if (end < 0) end = target_value.size() + end;
1238+
1239+
std::string result;
1240+
if (start < end && step == 1) {
1241+
result = s.substr(start, end - start);
1242+
} else {
1243+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
1244+
result += s[i];
1245+
}
1246+
}
1247+
return result;
1248+
1249+
} else if (target_value.is_array()) {
12331250
auto result = Value::array();
1234-
for (auto i = start; i < end; ++i) {
1251+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
12351252
result.push_back(target_value.at(i));
12361253
}
12371254
return result;
@@ -1523,6 +1540,10 @@ class MethodCallExpr : public Expression {
15231540
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
15241541
auto suffix = vargs.args[0].get<std::string>();
15251542
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
1543+
} else if (method->get_name() == "startswith") {
1544+
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
1545+
auto prefix = vargs.args[0].get<std::string>();
1546+
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
15261547
} else if (method->get_name() == "title") {
15271548
vargs.expectArgs("title method", {0, 0}, {0, 0});
15281549
auto res = str;
@@ -2085,28 +2106,37 @@ class Parser {
20852106

20862107
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
20872108
if (!consumeToken("[").empty()) {
2088-
std::shared_ptr<Expression> index;
2109+
std::shared_ptr<Expression> index;
2110+
auto slice_loc = get_location();
2111+
std::shared_ptr<Expression> start, end, step;
2112+
bool has_first_colon = false, has_second_colon = false;
2113+
2114+
if (!peekSymbols({ ":" })) {
2115+
start = parseExpression();
2116+
}
2117+
2118+
if (!consumeToken(":").empty()) {
2119+
has_first_colon = true;
2120+
if (!peekSymbols({ ":", "]" })) {
2121+
end = parseExpression();
2122+
}
20892123
if (!consumeToken(":").empty()) {
2090-
auto slice_end = parseExpression();
2091-
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
2092-
} else {
2093-
auto slice_start = parseExpression();
2094-
if (!consumeToken(":").empty()) {
2095-
consumeSpaces();
2096-
if (peekSymbols({ "]" })) {
2097-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
2098-
} else {
2099-
auto slice_end = parseExpression();
2100-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
2101-
}
2102-
} else {
2103-
index = std::move(slice_start);
2124+
has_second_colon = true;
2125+
if (!peekSymbols({ "]" })) {
2126+
step = parseExpression();
21042127
}
21052128
}
2106-
if (!index) throw std::runtime_error("Empty index in subscript");
2107-
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
2129+
}
2130+
2131+
if ((has_first_colon || has_second_colon) && (start || end || step)) {
2132+
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
2133+
} else {
2134+
index = std::move(start);
2135+
}
2136+
if (!index) throw std::runtime_error("Empty index in subscript");
2137+
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
21082138

2109-
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
2139+
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
21102140
} else if (!consumeToken(".").empty()) {
21112141
auto identifier = parseIdentifier();
21122142
if (!identifier) throw std::runtime_error("Expected identifier in subscript");

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -318,6 +318,7 @@ set(MODEL_IDS
318318
ValiantLabs/Llama3.1-8B-Enigma
319319
xwen-team/Xwen-72B-Chat
320320
xwen-team/Xwen-7B-Chat
321+
Qwen/Qwen3-4B
321322

322323
# Broken, TODO:
323324
# ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8

tests/test-syntax.cpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,9 @@ TEST(SyntaxTest, SimpleCases) {
184184
EXPECT_EQ(
185185
"1",
186186
render(R"({{ 1 | safe }})", {}, {}));
187+
EXPECT_EQ(
188+
"True,False",
189+
render(R"({{ 'abc'.startswith('ab') }},{{ ''.startswith('a') }})", {}, {}));
187190
EXPECT_EQ(
188191
"True,False",
189192
render(R"({{ 'abc'.endswith('bc') }},{{ ''.endswith('a') }})", {}, {}));
@@ -477,6 +480,15 @@ TEST(SyntaxTest, SimpleCases) {
477480
EXPECT_EQ(
478481
"[1, 2, 3][0, 1][1, 2]",
479482
render("{% set x = [0, 1, 2, 3] %}{{ x[1:] }}{{ x[:2] }}{{ x[1:3] }}", {}, {}));
483+
EXPECT_EQ(
484+
"123;01;12",
485+
render("{% set x = '0123' %}{{ x[1:] }};{{ x[:2] }};{{ x[1:3] }}", {}, {}));
486+
EXPECT_EQ(
487+
"[3, 2, 1, 0][3, 2, 1][2, 1, 0][2, 1][0, 2][3, 1][2, 0]",
488+
render("{% set x = [0, 1, 2, 3] %}{{ x[::-1] }}{{ x[:0:-1] }}{{ x[2::-1] }}{{ x[2:0:-1] }}{{ x[::2] }}{{ x[::-2] }}{{ x[-2::-2] }}", {}, {}));
489+
EXPECT_EQ(
490+
"3210;321;210;21;02;31;20",
491+
render("{% set x = '0123' %}{{ x[::-1] }};{{ x[:0:-1] }};{{ x[2::-1] }};{{ x[2:0:-1] }};{{ x[::2] }};{{ x[::-2] }};{{ x[-2::-2] }}", {}, {}));
480492
EXPECT_EQ(
481493
"a",
482494
render("{{ ' a ' | trim }}", {}, {}));

0 commit comments

Comments
 (0)