Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class Value : public std::enable_shared_from_this<Value> {
}
} else if (is_object()) {
if (!index.is_hashable())
throw std::runtime_error("Unashable type: " + index.dump());
throw std::runtime_error("Unhashable type: " + index.dump());
auto it = object_->find(index.primitive_);
if (it == object_->end())
throw std::runtime_error("Key not found: " + index.dump());
Expand All @@ -252,7 +252,7 @@ class Value : public std::enable_shared_from_this<Value> {
auto index = key.get<int>();
return array_->at(index < 0 ? array_->size() + index : index);
} else if (object_) {
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
auto it = object_->find(key.primitive_);
if (it == object_->end()) return Value();
return it->second;
Expand All @@ -261,7 +261,7 @@ class Value : public std::enable_shared_from_this<Value> {
}
void set(const Value& key, const Value& value) {
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
(*object_)[key.primitive_] = value;
}
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
Expand Down Expand Up @@ -398,7 +398,7 @@ class Value : public std::enable_shared_from_this<Value> {
}
return false;
} else if (object_) {
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
return object_->find(value.primitive_) != object_->end();
} else {
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
Expand All @@ -416,7 +416,7 @@ class Value : public std::enable_shared_from_this<Value> {
return const_cast<Value*>(this)->at(index);
}
Value& at(const Value & index) {
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
if (is_array()) return array_->at(index.get<int>());
if (is_object()) return object_->at(index.primitive_);
throw std::runtime_error("Value is not an array or object: " + dump());
Expand Down
19 changes: 19 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ target_link_libraries(test-syntax PRIVATE
gmock
)

if (WIN32)
message(STATUS "Skipping test-chat-template on Win32")
else()
add_executable(test-chat-template test-chat-template.cpp)
target_compile_features(test-chat-template PUBLIC cxx_std_17)
if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
target_compile_definitions(test-chat-template PUBLIC _CRT_SECURE_NO_WARNINGS)
target_compile_options(gtest PRIVATE -Wno-language-extension-token)
endif()
target_link_libraries(test-chat-template PRIVATE
nlohmann_json::nlohmann_json
gtest_main
gmock
)
endif()

add_executable(test-polyfills test-polyfills.cpp)
target_compile_features(test-polyfills PUBLIC cxx_std_17)
if (CMAKE_SYSTEM_NAME STREQUAL "Windows" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
Expand All @@ -31,6 +47,9 @@ target_link_libraries(test-polyfills PRIVATE
)
if (NOT CMAKE_CROSSCOMPILING)
gtest_discover_tests(test-syntax)
if (NOT WIN32)
gtest_discover_tests(test-chat-template)
endif()
add_test(NAME test-polyfills COMMAND test-polyfills)
set_tests_properties(test-polyfills PROPERTIES WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endif()
Expand Down
69 changes: 69 additions & 0 deletions tests/test-chat-template.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@

/*
Copyright 2024 Google LLC

Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#include "chat-template.hpp"
#include "gtest/gtest.h"
#include <gtest/gtest.h>
#include <gmock/gmock-matchers.h>

#include <fstream>
#include <iostream>
#include <string>

using namespace minja;
using namespace testing;

static std::string render_python(const std::string & template_str, const chat_template_inputs & inputs) {
json bindings = inputs.extra_context;
bindings["messages"] = inputs.messages;
bindings["tools"] = inputs.tools;
bindings["add_generation_prompt"] = inputs.add_generation_prompt;
json data {
{"template", template_str},
{"bindings", bindings},
{"options", {
{"trim_blocks", true},
{"lstrip_blocks", true},
{"keep_trailing_newline", false},
}},
};
{
std::ofstream of("data.json");
of << data.dump(2);
of.close();
}

auto pyExeEnv = getenv("PYTHON_EXECUTABLE");
std::string pyExe = pyExeEnv ? pyExeEnv : "python3";

std::remove("out.txt");
auto res = std::system((pyExe + " -m scripts.render data.json out.txt").c_str());
if (res != 0) {
throw std::runtime_error("Failed to run python script with data: " + data.dump(2));
}

std::ifstream f("out.txt");
std::string out((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
return out;
}

static std::string render(const std::string & template_str, const chat_template_inputs & inputs, const chat_template_options & opts) {
if (getenv("USE_JINJA2")) {
return render_python(template_str, inputs);
}
chat_template tmpl(
template_str,
"",
"");
return tmpl.apply(inputs, opts);
}

TEST(ChatTemplateTest, SimpleCases) {
EXPECT_THAT(render("{{ strftime_now('%Y-%m-%d %H:%M:%S') }}", {}, {}), MatchesRegex(R"([0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2})"));
}
Loading