Skip to content

Commit f28fe03

Browse files
authored
Merge pull request #5 from thammegowda/tg/chat-template
add chat template jinja rendering with minijinja
2 parents 71646b8 + 1428437 commit f28fe03

File tree

13 files changed

+1194
-49
lines changed

13 files changed

+1194
-49
lines changed

.gitmodules

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
[submodule "bindings/cpp/third_party/Jinja2Cpp"]
2-
path = bindings/cpp/third_party/Jinja2Cpp
3-
url = https://github.com/jinja2cpp/Jinja2Cpp.git

bindings/c/src/lib.rs

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,3 +672,124 @@ pub extern "C" fn tokenizers_get_chat_template(tokenizer: *mut c_void) -> *mut c
672672
ptr::null_mut()
673673
}
674674

675+
/// Apply a chat template to render messages
676+
///
677+
/// Arguments:
678+
/// - tokenizer: the tokenizer instance
679+
/// - template: Jinja2 template string
680+
/// - messages_json: JSON array of messages with "role" and "content" fields
681+
/// - add_generation_prompt: whether to append generation prompt
682+
/// - bos_token: optional BOS token string
683+
/// - eos_token: optional EOS token string
684+
/// - error_out: pointer to error string (caller must free with tokenizers_string_free)
685+
///
686+
/// Returns: rendered template string (caller must free with tokenizers_string_free), or null on error
687+
#[no_mangle]
688+
pub extern "C" fn tokenizers_apply_chat_template(
689+
tokenizer: *mut c_void,
690+
template: *const c_char,
691+
messages_json: *const c_char,
692+
add_generation_prompt: bool,
693+
bos_token: *const c_char,
694+
eos_token: *const c_char,
695+
error_out: *mut *mut c_char,
696+
) -> *mut c_char {
697+
if tokenizer.is_null() || template.is_null() || messages_json.is_null() {
698+
if !error_out.is_null() {
699+
let err = CString::new("Invalid arguments: null pointers provided").unwrap();
700+
unsafe { *error_out = err.into_raw(); }
701+
}
702+
return ptr::null_mut();
703+
}
704+
705+
let template_str = match unsafe { CStr::from_ptr(template) }.to_str() {
706+
Ok(s) => s,
707+
Err(_) => {
708+
if !error_out.is_null() {
709+
let err = CString::new("Invalid template string encoding").unwrap();
710+
unsafe { *error_out = err.into_raw(); }
711+
}
712+
return ptr::null_mut();
713+
}
714+
};
715+
716+
let messages_json_str = match unsafe { CStr::from_ptr(messages_json) }.to_str() {
717+
Ok(s) => s,
718+
Err(_) => {
719+
if !error_out.is_null() {
720+
let err = CString::new("Invalid messages JSON encoding").unwrap();
721+
unsafe { *error_out = err.into_raw(); }
722+
}
723+
return ptr::null_mut();
724+
}
725+
};
726+
727+
let bos_opt = if !bos_token.is_null() {
728+
match unsafe { CStr::from_ptr(bos_token) }.to_str() {
729+
Ok(s) => Some(s.to_string()),
730+
Err(_) => {
731+
if !error_out.is_null() {
732+
let err = CString::new("Invalid BOS token encoding").unwrap();
733+
unsafe { *error_out = err.into_raw(); }
734+
}
735+
return ptr::null_mut();
736+
}
737+
}
738+
} else {
739+
None
740+
};
741+
742+
let eos_opt = if !eos_token.is_null() {
743+
match unsafe { CStr::from_ptr(eos_token) }.to_str() {
744+
Ok(s) => Some(s.to_string()),
745+
Err(_) => {
746+
if !error_out.is_null() {
747+
let err = CString::new("Invalid EOS token encoding").unwrap();
748+
unsafe { *error_out = err.into_raw(); }
749+
}
750+
return ptr::null_mut();
751+
}
752+
}
753+
} else {
754+
None
755+
};
756+
757+
// Parse messages JSON
758+
let messages: Vec<tokenizers::Message> = match serde_json::from_str(messages_json_str) {
759+
Ok(msgs) => msgs,
760+
Err(e) => {
761+
if !error_out.is_null() {
762+
let err = CString::new(format!("Failed to parse messages JSON: {}", e)).unwrap();
763+
unsafe { *error_out = err.into_raw(); }
764+
}
765+
return ptr::null_mut();
766+
}
767+
};
768+
769+
// Create and apply chat template
770+
match tokenizers::ChatTemplate::new(template_str.to_string(), bos_opt, eos_opt) {
771+
Ok(chat_template) => {
772+
let inputs = tokenizers::ChatTemplateInputs::new(messages, add_generation_prompt);
773+
match chat_template.apply(inputs) {
774+
Ok(result) => {
775+
CString::new(result).unwrap().into_raw()
776+
}
777+
Err(e) => {
778+
if !error_out.is_null() {
779+
let err = CString::new(format!("Template rendering failed: {}", e)).unwrap();
780+
unsafe { *error_out = err.into_raw(); }
781+
}
782+
ptr::null_mut()
783+
}
784+
}
785+
}
786+
Err(e) => {
787+
if !error_out.is_null() {
788+
let err = CString::new(format!("Failed to compile template: {}", e)).unwrap();
789+
unsafe { *error_out = err.into_raw(); }
790+
}
791+
ptr::null_mut()
792+
}
793+
}
794+
}
795+

bindings/c/tokenizers_c.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
extern "C" {
99
#endif
1010

11+
// Only define the struct if not already defined
12+
#ifndef TOKENIZERS_ENCODING_T_DEFINED
13+
#define TOKENIZERS_ENCODING_T_DEFINED
1114
typedef struct {
1215
const int* ids;
1316
const int* attention_mask;
1417
size_t len;
1518
void* _internal_ptr; // Internal use only - do not access
1619
} tokenizers_encoding_t;
20+
#endif
1721

1822
// Create a new tokenizer from a JSON file (auto-loads tokenizer_config.json if present)
1923
void* tokenizers_new_from_file(const char* path);
@@ -77,6 +81,26 @@ bool tokenizers_has_chat_template(void* tokenizer);
7781
// Get chat template string (must be freed with tokenizers_string_free)
7882
char* tokenizers_get_chat_template(void* tokenizer);
7983

84+
// Apply a chat template to render messages
85+
// Arguments:
86+
// - tokenizer: the tokenizer instance
87+
// - template_str: Jinja2 template string
88+
// - messages_json: JSON array of messages with "role" and "content" fields
89+
// - add_generation_prompt: whether to append generation prompt
90+
// - bos_token: optional BOS token string (can be NULL)
91+
// - eos_token: optional EOS token string (can be NULL)
92+
// - error_out: optional out-parameter (can be NULL); on failure it receives an error message (caller must free with tokenizers_string_free)
93+
// Returns: rendered template string (caller must free with tokenizers_string_free), or NULL on error
94+
char* tokenizers_apply_chat_template(
95+
void* tokenizer,
96+
const char* template_str,
97+
const char* messages_json,
98+
bool add_generation_prompt,
99+
const char* bos_token,
100+
const char* eos_token,
101+
char** error_out
102+
);
103+
80104
#ifdef __cplusplus
81105
}
82106
#endif

bindings/cpp/CMakeLists.txt

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ set(RUST_CRATE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../c)
1414
set(RUST_OUTPUT_DIR ${RUST_CRATE_DIR}/target/release)
1515
set(RUST_LIB_NAME tokenizers_c)
1616

17-
# Jinja2Cpp for chat template rendering
18-
set(JINJA2CPP_BUILD_TESTS OFF CACHE BOOL "" FORCE)
19-
set(JINJA2CPP_BUILD_SHARED OFF CACHE BOOL "" FORCE)
20-
set(JINJA2CPP_DEPS_MODE internal CACHE STRING "" FORCE)
21-
add_subdirectory(third_party/Jinja2Cpp)
22-
2317
# Custom command to build the Rust cdylib
2418
add_custom_command(
2519
OUTPUT ${RUST_OUTPUT_DIR}/lib${RUST_LIB_NAME}.so
@@ -43,8 +37,11 @@ add_library(tokenizers_cpp_impl STATIC
4337
src/tokenizers.cpp
4438
)
4539
add_dependencies(tokenizers_cpp_impl build_rust_ffi)
46-
target_include_directories(tokenizers_cpp_impl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
47-
target_link_libraries(tokenizers_cpp_impl PUBLIC ${RUST_LIB_NAME} jinja2cpp)
40+
target_include_directories(tokenizers_cpp_impl
41+
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include
42+
PRIVATE ${RUST_CRATE_DIR}
43+
)
44+
target_link_libraries(tokenizers_cpp_impl PUBLIC ${RUST_LIB_NAME})
4845

4946
# Interface library for easy linking
5047
add_library(tokenizers_cpp INTERFACE)
@@ -66,6 +63,7 @@ if(TOKENIZERS_COMPILE_TESTS)
6663
# Google Test executable
6764
add_executable(tokenizer_tests_gtest
6865
tests/test_tokenizer_gtest.cpp
66+
tests/test_tokenizer_chat_templates.cpp
6967
)
7068
target_link_libraries(tokenizer_tests_gtest PRIVATE tokenizers_cpp GTest::gtest_main)
7169

bindings/cpp/include/tokenizers/tokenizers.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ extern "C" {
6363
bool tokenizers_get_add_eos_token(void* tokenizer);
6464
bool tokenizers_has_chat_template(void* tokenizer);
6565
char* tokenizers_get_chat_template(void* tokenizer);
66+
char* tokenizers_apply_chat_template(
67+
void* tokenizer,
68+
const char* template_str,
69+
const char* messages_json,
70+
bool add_generation_prompt,
71+
const char* bos_token,
72+
const char* eos_token,
73+
char** error_out
74+
);
6675
}
6776

6877
namespace tokenizers {
@@ -391,6 +400,18 @@ class Tokenizer {
391400
bool add_generation_prompt = true
392401
) const;
393402

403+
/// Apply custom chat template to messages
404+
/// @param template_str The Jinja2 chat template string to use
405+
/// @param messages Vector of ChatMessage with role and content
406+
/// @param add_generation_prompt If true, adds prompt for assistant response
407+
/// @return Formatted string ready for tokenization
408+
/// @throws ChatTemplateError if template rendering fails
409+
std::string apply_chat_template(
410+
const std::string& template_str,
411+
const std::vector<ChatMessage>& messages,
412+
bool add_generation_prompt = true
413+
) const;
414+
394415
bool valid() const { return handle_ != nullptr; }
395416

396417
static std::string version() {

bindings/cpp/src/tokenizers.cpp

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,57 +3,96 @@
33
*/
44

55
#include <tokenizers/tokenizers.h>
6-
#include <jinja2cpp/template.h>
7-
#include <jinja2cpp/value.h>
6+
#include <sstream>
7+
#include <iomanip>
88

99
namespace tokenizers {
1010

11+
// Escape a string so it can be embedded between double quotes in a JSON document.
//
// - `"` and `\` are backslash-escaped.
// - \b, \f, \n, \r and \t use their short JSON escapes.
// - Any other byte below 0x20 is written as \u00XX (lowercase hex).
// - All remaining bytes (including multi-byte UTF-8 sequences) pass through unchanged.
//
// Returns the escaped text without surrounding quotes.
static std::string json_escape(const std::string& input) {
    static const char kHexDigits[] = "0123456789abcdef";

    std::string output;
    // Leave roughly 10% headroom for escapes, using integer arithmetic only.
    output.reserve(input.size() + input.size() / 10);

    // Iterate as unsigned so bytes >= 0x80 are never mistaken for control characters.
    for (unsigned char c : input) {
        switch (c) {
            case '"':  output += "\\\""; break;
            case '\\': output += "\\\\"; break;
            case '\b': output += "\\b";  break;
            case '\f': output += "\\f";  break;
            case '\n': output += "\\n";  break;
            case '\r': output += "\\r";  break;
            case '\t': output += "\\t";  break;
            default:
                if (c < 0x20) {
                    // Remaining control characters: c < 0x20, so the high byte is always 00.
                    output += "\\u00";
                    output += kHexDigits[c >> 4];
                    output += kHexDigits[c & 0x0F];
                } else {
                    output += static_cast<char>(c);
                }
                break;
        }
    }
    return output;
}
37+
1138
std::string Tokenizer::apply_chat_template(
39+
const std::string& template_str,
1240
const std::vector<ChatMessage>& messages,
1341
bool add_generation_prompt
1442
) const {
15-
// Get the template string
16-
std::string tmpl_str = chat_template();
17-
if (tmpl_str.empty()) {
18-
throw ChatTemplateError("No chat template available for this tokenizer");
43+
// Build messages JSON array manually
44+
std::stringstream ss;
45+
ss << "[";
46+
for (size_t i = 0; i < messages.size(); ++i) {
47+
if (i > 0) ss << ",";
48+
ss << "{\"role\":\"" << json_escape(messages[i].role)
49+
<< "\",\"content\":\"" << json_escape(messages[i].content) << "\"}";
1950
}
51+
ss << "]";
52+
std::string messages_json_str = ss.str();
2053

21-
// Create Jinja2 template
22-
jinja2::Template tpl;
23-
auto load_result = tpl.Load(tmpl_str, "chat_template");
24-
if (!load_result) {
25-
throw ChatTemplateError("Failed to parse chat template: " +
26-
load_result.error().ToString());
27-
}
54+
// Get special tokens (pass as C strings, can be null)
55+
std::string bos_str = bos_token();
56+
std::string eos_str = eos_token();
57+
const char* bos_ptr = bos_str.empty() ? nullptr : bos_str.c_str();
58+
const char* eos_ptr = eos_str.empty() ? nullptr : eos_str.c_str();
2859

29-
// Convert messages to Jinja2 values
30-
jinja2::ValuesList jinja_messages;
31-
for (const auto& msg : messages) {
32-
jinja2::ValuesMap msg_map;
33-
msg_map["role"] = msg.role;
34-
msg_map["content"] = msg.content;
35-
jinja_messages.push_back(std::move(msg_map));
36-
}
60+
// Call C FFI function with custom template
61+
char* error_msg = nullptr;
62+
char* result = tokenizers_apply_chat_template(
63+
handle_,
64+
template_str.c_str(),
65+
messages_json_str.c_str(),
66+
add_generation_prompt,
67+
bos_ptr,
68+
eos_ptr,
69+
&error_msg
70+
);
3771

38-
// Build parameters map
39-
jinja2::ValuesMap params;
40-
params["messages"] = std::move(jinja_messages);
41-
params["add_generation_prompt"] = add_generation_prompt;
72+
if (result == nullptr) {
73+
std::string error = error_msg ? error_msg : "Failed to apply chat template";
74+
if (error_msg) {
75+
tokenizers_string_free(error_msg);
76+
}
77+
throw ChatTemplateError(error);
78+
}
4279

43-
// Add special tokens as variables (commonly used in templates)
44-
params["bos_token"] = bos_token();
45-
params["eos_token"] = eos_token();
46-
params["pad_token"] = pad_token();
47-
params["unk_token"] = unk_token();
80+
std::string rendered(result);
81+
tokenizers_string_free(result);
4882

49-
// Render the template
50-
auto render_result = tpl.RenderAsString(params);
51-
if (!render_result) {
52-
throw ChatTemplateError("Failed to render chat template: " +
53-
render_result.error().ToString());
83+
return rendered;
84+
}
85+
86+
std::string Tokenizer::apply_chat_template(
87+
const std::vector<ChatMessage>& messages,
88+
bool add_generation_prompt
89+
) const {
90+
// Get the template string from config and delegate to the overload
91+
std::string tmpl_str = chat_template();
92+
if (tmpl_str.empty()) {
93+
throw ChatTemplateError("No chat template available for this tokenizer");
5494
}
55-
56-
return render_result.value();
95+
return apply_chat_template(tmpl_str, messages, add_generation_prompt);
5796
}
5897

5998
} // namespace tokenizers

0 commit comments

Comments
 (0)