|
| 1 | +#![deny(warnings)] |
| 2 | + |
| 3 | +use minijinja::Environment; |
| 4 | +use serde::Serialize; |
| 5 | +use std::sync::{ |
| 6 | + atomic::{AtomicUsize, Ordering::Relaxed}, |
| 7 | + OnceLock, RwLock, |
| 8 | +}; |
| 9 | + |
| 10 | +#[repr(transparent)] |
| 11 | +pub struct ChatTemplate(String); |
| 12 | + |
| 13 | +#[derive(Serialize)] |
| 14 | +pub struct Message<'a> { |
| 15 | + pub role: &'a str, |
| 16 | + pub content: &'a str, |
| 17 | +} |
| 18 | + |
| 19 | +impl ChatTemplate { |
| 20 | + pub fn new(template: String) -> Self { |
| 21 | + static NEXT: AtomicUsize = AtomicUsize::new(0); |
| 22 | + let id = NEXT.fetch_add(1, Relaxed).to_string(); |
| 23 | + |
| 24 | + jinja() |
| 25 | + .write() |
| 26 | + .unwrap() |
| 27 | + .add_template_owned(id.clone(), template) |
| 28 | + .unwrap(); |
| 29 | + |
| 30 | + Self(id) |
| 31 | + } |
| 32 | + |
| 33 | + pub fn render( |
| 34 | + &self, |
| 35 | + messages: &[Message<'_>], |
| 36 | + bos_token: &str, |
| 37 | + eos_token: &str, |
| 38 | + add_generation_prompt: bool, |
| 39 | + ) -> Result<String, minijinja::Error> { |
| 40 | + #[derive(Serialize)] |
| 41 | + struct Args<'a> { |
| 42 | + messages: &'a [Message<'a>], |
| 43 | + bos_token: &'a str, |
| 44 | + eos_token: &'a str, |
| 45 | + add_generation_prompt: bool, |
| 46 | + } |
| 47 | + |
| 48 | + jinja() |
| 49 | + .read() |
| 50 | + .unwrap() |
| 51 | + .get_template(&self.0) |
| 52 | + .unwrap() |
| 53 | + .render(Args { |
| 54 | + messages, |
| 55 | + bos_token, |
| 56 | + eos_token, |
| 57 | + add_generation_prompt, |
| 58 | + }) |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +impl Drop for ChatTemplate { |
| 63 | + fn drop(&mut self) { |
| 64 | + jinja().write().unwrap().remove_template(&self.0); |
| 65 | + } |
| 66 | +} |
| 67 | + |
| 68 | +fn jinja() -> &'static RwLock<Environment<'static>> { |
| 69 | + static ENV: OnceLock<RwLock<Environment<'_>>> = OnceLock::new(); |
| 70 | + ENV.get_or_init(|| { |
| 71 | + let mut env = Environment::empty(); |
| 72 | + env.set_unknown_method_callback(|_, value, method, args| { |
| 73 | + use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value}; |
| 74 | + match (method, value.kind(), args) { |
| 75 | + ("strip", ThisType::String, []) => Ok(Value::from_safe_string( |
| 76 | + value.to_str().unwrap().trim().into(), |
| 77 | + )), |
| 78 | + _ => Err(UnknownMethod.into()), |
| 79 | + } |
| 80 | + }); |
| 81 | + RwLock::new(env) |
| 82 | + }) |
| 83 | +} |
| 84 | + |
| 85 | +#[test] |
| 86 | +fn test() { |
| 87 | + const TAIDE: &str = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]'}}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"; |
| 88 | + const MINICPM: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}"; |
| 89 | + |
| 90 | + let result = ChatTemplate::new(TAIDE.into()) |
| 91 | + .render( |
| 92 | + &[Message { |
| 93 | + role: "user", |
| 94 | + content: "Hello, who are you?", |
| 95 | + }], |
| 96 | + "<s>", |
| 97 | + "</s>", |
| 98 | + true, |
| 99 | + ) |
| 100 | + .unwrap(); |
| 101 | + |
| 102 | + assert_eq!( |
| 103 | + result, |
| 104 | + "<s>[INST] Hello, who are you? [/INST]<|im_start|>assistant\n" |
| 105 | + ); |
| 106 | + |
| 107 | + let result = ChatTemplate::new(MINICPM.into()) |
| 108 | + .render( |
| 109 | + &[Message { |
| 110 | + role: "user", |
| 111 | + content: "Hello, who are you?", |
| 112 | + }], |
| 113 | + "<s>", |
| 114 | + "</s>", |
| 115 | + true, |
| 116 | + ) |
| 117 | + .unwrap(); |
| 118 | + assert_eq!(result, "<用户>Hello, who are you?<AI>"); |
| 119 | +} |
0 commit comments