Skip to content

Commit f688768

Browse files
authored
Merge pull request #16 from InfiniTensor/dev
重构 chat-template 和 tokenizer
2 parents 30578bf + 227edbf commit f688768

File tree

34 files changed

+1081
-463
lines changed

34 files changed

+1081
-463
lines changed

Cargo.lock

Lines changed: 151 additions & 75 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ members = [
44
"tensor",
55
"tokenizer",
66
"causal-lm",
7+
"chat-template",
78
"service",
89
"web-api",
910
"xtask",
@@ -12,6 +13,7 @@ members = [
1213
"devices/common-cpu",
1314
"devices/nvidia-gpu",
1415
"devices/cambricon-mlu",
16+
"devices/ascend-card",
1517

1618
"models/llama/common",
1719
"models/llama/common-cpu",
@@ -34,6 +36,7 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
3436
digit-layout = "0.0"
3537
build-script-cfg = "0.0"
3638

37-
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "5a88159", default-features = false }
38-
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "fb088b6" }
39+
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "e6ee6ea", default-features = false }
40+
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d089ada" }
3941
search-neuware-tools = "0.0"
42+
search-ascend-tools = { git = "https://github.com/InfiniTensor/ascendcl", rev = "1e7a696" }

causal-lm/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ pub trait CausalLM: Model {
4343
type Storage;
4444
/// 最大序列长度。
4545
fn max_seq_len(&self) -> upos;
46+
/// 模型定义的句子起始符。
47+
fn bos_token(&self) -> utok;
4648
/// 模型定义的句子结束符。
4749
fn eos_token(&self) -> utok;
4850
/// 创建一个未填充的缓存张量(`num_layers x 2 x num_kv_head x max_seq_len x head_dim`)。

chat-template/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "chat-template"
3+
version = "0.0.0"
4+
edition = "2021"
5+
authors = ["YdrMaster <[email protected]>"]
6+
7+
[dependencies]
8+
serde = { workspace = true, features = ["derive"] }
9+
minijinja = { version = "2.1", default-features = false, features = ["loader"] }

chat-template/src/lib.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#![deny(warnings)]
2+
3+
use minijinja::Environment;
4+
use serde::Serialize;
5+
use std::sync::{
6+
atomic::{AtomicUsize, Ordering::Relaxed},
7+
OnceLock, RwLock,
8+
};
9+
10+
/// Handle to a chat template registered in the process-wide template environment.
///
/// The wrapped `String` is the id under which the template source was stored
/// when the handle was created; rendering looks the template up by that id,
/// and dropping the handle removes the entry again.
#[repr(transparent)]
11+
pub struct ChatTemplate(String);
12+
13+
/// One conversation entry handed to a chat template.
///
/// Serialized with the field names `role` and `content`, which is how
/// templates address them (e.g. `message['role']`, `message['content']`).
#[derive(Serialize)]
14+
pub struct Message<'a> {
15+
/// Who produced this entry; the templates exercised in this crate's test
/// compare it against `'system'`, `'user'` and `'assistant'`.
pub role: &'a str,
16+
/// The text of this entry.
pub content: &'a str,
17+
}
18+
19+
impl ChatTemplate {
20+
pub fn new(template: String) -> Self {
21+
static NEXT: AtomicUsize = AtomicUsize::new(0);
22+
let id = NEXT.fetch_add(1, Relaxed).to_string();
23+
24+
jinja()
25+
.write()
26+
.unwrap()
27+
.add_template_owned(id.clone(), template)
28+
.unwrap();
29+
30+
Self(id)
31+
}
32+
33+
pub fn render(
34+
&self,
35+
messages: &[Message<'_>],
36+
bos_token: &str,
37+
eos_token: &str,
38+
add_generation_prompt: bool,
39+
) -> Result<String, minijinja::Error> {
40+
#[derive(Serialize)]
41+
struct Args<'a> {
42+
messages: &'a [Message<'a>],
43+
bos_token: &'a str,
44+
eos_token: &'a str,
45+
add_generation_prompt: bool,
46+
}
47+
48+
jinja()
49+
.read()
50+
.unwrap()
51+
.get_template(&self.0)
52+
.unwrap()
53+
.render(Args {
54+
messages,
55+
bos_token,
56+
eos_token,
57+
add_generation_prompt,
58+
})
59+
}
60+
}
61+
62+
impl Drop for ChatTemplate {
63+
fn drop(&mut self) {
64+
jinja().write().unwrap().remove_template(&self.0);
65+
}
66+
}
67+
68+
fn jinja() -> &'static RwLock<Environment<'static>> {
69+
static ENV: OnceLock<RwLock<Environment<'_>>> = OnceLock::new();
70+
ENV.get_or_init(|| {
71+
let mut env = Environment::empty();
72+
env.set_unknown_method_callback(|_, value, method, args| {
73+
use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
74+
match (method, value.kind(), args) {
75+
("strip", ThisType::String, []) => Ok(Value::from_safe_string(
76+
value.to_str().unwrap().trim().into(),
77+
)),
78+
_ => Err(UnknownMethod.into()),
79+
}
80+
});
81+
RwLock::new(env)
82+
})
83+
}
84+
85+
/// Renders a single user greeting through two real-world templates and checks
/// the exact output, including the custom `.strip()` method used by MiniCPM.
#[test]
fn test() {
    const TAIDE: &str = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = '<<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content + ' [/INST]'}}{% elif message['role'] == 'assistant' %}{{ ' ' + content + ' ' + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}";
    const MINICPM: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}";

    // Both cases use the same conversation and tokens; only the template differs.
    let render_greeting = |source: &str| {
        let conversation = [Message {
            role: "user",
            content: "Hello, who are you?",
        }];
        let template = ChatTemplate::new(source.into());
        template
            .render(&conversation, "<s>", "</s>", true)
            .unwrap()
    };

    assert_eq!(
        render_greeting(TAIDE),
        "<s>[INST] Hello, who are you? [/INST]<|im_start|>assistant\n"
    );
    assert_eq!(render_greeting(MINICPM), "<用户>Hello, who are you?<AI>");
}

common/src/between_f32.rs

Lines changed: 0 additions & 39 deletions
This file was deleted.

common/src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@ pub type utok = u32;
1010
#[allow(non_camel_case_types)]
1111
pub type upos = u32;
1212

13-
mod between_f32;
1413
mod blob;
1514
pub mod safe_tensors;
1615
pub mod test_model;
1716

18-
pub use between_f32::BetweenF32;
1917
pub use blob::Blob;
2018
pub use half::{bf16, f16};
2119

devices/ascend-card/Cargo.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "common-acl"
3+
version = "0.0.0"
4+
edition = "2021"
5+
authors = ["YdrMaster <[email protected]>"]
6+
7+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8+
9+
[dependencies]
10+
common = { path = "../../common" }
11+
common-devices = { path = "../common" }
12+
tensor = { path = "../../tensor" }
13+
operators = { workspace = true, features = ["ascend-card"] }
14+
15+
[build-dependencies]
16+
build-script-cfg.workspace = true
17+
search-ascend-tools.workspace = true

devices/ascend-card/build.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
fn main() {
2+
use build_script_cfg::Cfg;
3+
use search_ascend_tools::find_ascend_toolkit_home;
4+
5+
let ascend = Cfg::new("detected_ascend");
6+
if find_ascend_toolkit_home().is_some() {
7+
ascend.define();
8+
}
9+
}

devices/ascend-card/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#![cfg(detected_ascend)]
2+
3+
pub use operators::ascendcl;
4+
pub use tensor::Tensor;

0 commit comments

Comments
 (0)