Skip to content

Commit b810bce

Browse files
authored
Test strftime_now (+ nit typo fixes) (#59)
* Test strftime_now in new test-chat-template * opportunistic typo fixes (Unashable -> Unhashable) * avoid gtest regex weirdness on win32
1 parent 84187ba commit b810bce

File tree

3 files changed

+93
-5
lines changed

3 files changed

+93
-5
lines changed

include/minja/minja.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ class Value : public std::enable_shared_from_this<Value> {
233233
}
234234
} else if (is_object()) {
235235
if (!index.is_hashable())
236-
throw std::runtime_error("Unashable type: " + index.dump());
236+
throw std::runtime_error("Unhashable type: " + index.dump());
237237
auto it = object_->find(index.primitive_);
238238
if (it == object_->end())
239239
throw std::runtime_error("Key not found: " + index.dump());
@@ -252,7 +252,7 @@ class Value : public std::enable_shared_from_this<Value> {
252252
auto index = key.get<int>();
253253
return array_->at(index < 0 ? array_->size() + index : index);
254254
} else if (object_) {
255-
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
255+
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
256256
auto it = object_->find(key.primitive_);
257257
if (it == object_->end()) return Value();
258258
return it->second;
@@ -261,7 +261,7 @@ class Value : public std::enable_shared_from_this<Value> {
261261
}
262262
void set(const Value& key, const Value& value) {
263263
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
264-
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
264+
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
265265
(*object_)[key.primitive_] = value;
266266
}
267267
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@@ -398,7 +398,7 @@ class Value : public std::enable_shared_from_this<Value> {
398398
}
399399
return false;
400400
} else if (object_) {
401-
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
401+
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
402402
return object_->find(value.primitive_) != object_->end();
403403
} else {
404404
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
@@ -416,7 +416,7 @@ class Value : public std::enable_shared_from_this<Value> {
416416
return const_cast<Value*>(this)->at(index);
417417
}
418418
Value& at(const Value & index) {
419-
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
419+
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
420420
if (is_array()) return array_->at(index.get<int>());
421421
if (is_object()) return object_->at(index.primitive_);
422422
throw std::runtime_error("Value is not an array or object: " + dump());

tests/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@ target_link_libraries(test-syntax PRIVATE
1818
gmock
1919
)
2020

21+
if (WIN32)
22+
message(STATUS "Skipping test-chat-template on Win32")
23+
else()
24+
add_executable(test-chat-template test-chat-template.cpp)
25+
target_compile_features(test-chat-template PUBLIC cxx_std_17)
26+
if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
27+
target_compile_definitions(test-chat-template PUBLIC _CRT_SECURE_NO_WARNINGS)
28+
target_compile_options(gtest PRIVATE -Wno-language-extension-token)
29+
endif()
30+
target_link_libraries(test-chat-template PRIVATE
31+
nlohmann_json::nlohmann_json
32+
gtest_main
33+
gmock
34+
)
35+
endif()
36+
2137
add_executable(test-polyfills test-polyfills.cpp)
2238
target_compile_features(test-polyfills PUBLIC cxx_std_17)
2339
if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
@@ -31,6 +47,9 @@ target_link_libraries(test-polyfills PRIVATE
3147
)
3248
if (NOT CMAKE_CROSSCOMPILING)
3349
gtest_discover_tests(test-syntax)
50+
if (NOT WIN32)
51+
gtest_discover_tests(test-chat-template)
52+
endif()
3453
add_test(NAME test-polyfills COMMAND test-polyfills)
3554
set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
3655
endif()

tests/test-chat-template.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
2+
/*
3+
Copyright 2024 Google LLC
4+
5+
Use of this source code is governed by an MIT-style
6+
license that can be found in the LICENSE file or at
7+
https://opensource.org/licenses/MIT.
8+
*/
9+
// SPDX-License-Identifier: MIT
10+
#include "chat-template.hpp"
11+
#include "gtest/gtest.h"
12+
#include <gtest/gtest.h>
13+
#include <gmock/gmock-matchers.h>
14+
15+
#include <fstream>
16+
#include <iostream>
17+
#include <string>
18+
19+
using namespace minja;
20+
using namespace testing;
21+
22+
static std::string render_python(const std::string & template_str, const chat_template_inputs & inputs) {
23+
json bindings = inputs.extra_context;
24+
bindings["messages"] = inputs.messages;
25+
bindings["tools"] = inputs.tools;
26+
bindings["add_generation_prompt"] = inputs.add_generation_prompt;
27+
json data {
28+
{"template", template_str},
29+
{"bindings", bindings},
30+
{"options", {
31+
{"trim_blocks", true},
32+
{"lstrip_blocks", true},
33+
{"keep_trailing_newline", false},
34+
}},
35+
};
36+
{
37+
std::ofstream of("data.json");
38+
of << data.dump(2);
39+
of.close();
40+
}
41+
42+
auto pyExeEnv = getenv("PYTHON_EXECUTABLE");
43+
std::string pyExe = pyExeEnv ? pyExeEnv : "python3";
44+
45+
std::remove("out.txt");
46+
auto res = std::system((pyExe + " -m scripts.render data.json out.txt").c_str());
47+
if (res != 0) {
48+
throw std::runtime_error("Failed to run python script with data: " + data.dump(2));
49+
}
50+
51+
std::ifstream f("out.txt");
52+
std::string out((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
53+
return out;
54+
}
55+
56+
static std::string render(const std::string & template_str, const chat_template_inputs & inputs, const chat_template_options & opts) {
57+
if (getenv("USE_JINJA2")) {
58+
return render_python(template_str, inputs);
59+
}
60+
chat_template tmpl(
61+
template_str,
62+
"",
63+
"");
64+
return tmpl.apply(inputs, opts);
65+
}
66+
67+
TEST(ChatTemplateTest, SimpleCases) {
68+
EXPECT_THAT(render("{{ strftime_now('%Y-%m-%d %H:%M:%S') }}", {}, {}), MatchesRegex(R"([0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2})"));
69+
}

0 commit comments

Comments
 (0)