
Commit d68a0eb

feat(causal-lm): detect whether spaces need replacing; restore the mechanism for adding and replacing spaces

Signed-off-by: YdrMaster <[email protected]>
1 parent 1973925
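The tokenizer changes the commit title refers to live in files not shown in this extract. For background, the usual convention in GGuf/SentencePiece-style tokenizers is to encode spaces as the metasymbol ▁ (U+2581); below is a minimal sketch of such a detect-then-replace scheme, with hypothetical helper names, not the commit's actual code:

// Sketch only: hypothetical helpers illustrating the detect/replace idea.
// A vocabulary containing the SentencePiece metasymbol '▁' (U+2581) encodes
// spaces as that symbol, so text must be preprocessed before encoding and
// postprocessed after decoding; a vocabulary without it needs neither step.
fn needs_space_replacement(vocab: &[&str]) -> bool {
    vocab.iter().any(|piece| piece.contains('▁'))
}

fn replace_spaces(text: &str) -> String {
    text.replace(' ', "\u{2581}")
}

fn restore_spaces(piece: &str) -> String {
    piece.replace('\u{2581}', " ")
}

fn main() {
    let vocab = ["▁hello", "world"];
    assert!(needs_space_replacement(&vocab));
    assert_eq!(replace_spaces("hello world"), "hello▁world");
    assert_eq!(restore_spaces("▁hello"), " hello");
}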

File tree

6 files changed: +235 -224 lines changed
causal-lm/src/chat_template.rs

Lines changed: 40 additions & 44 deletions

@@ -1,4 +1,4 @@
-use crate::Tokenize;
+use crate::Tokenizer;
 use common::GGufModel;
 use minijinja::Environment;
 use serde::Serialize;
@@ -21,39 +21,38 @@ pub struct Message<'a> {
     pub content: &'a str,
 }
 
-/// Build a chat template from the GGuf model.
-pub fn build_render(gguf: &GGufModel, tokenize: &dyn Tokenize) -> Option<ChatTemplate> {
-    let template = gguf
-        .meta_kvs
-        .get("tokenizer.chat_template")?
-        .value_reader()
-        .read_str()
-        .unwrap()
-        .into();
-
-    let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
-    let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
-        .value_reader()
-        .read::<utok>()
-        .unwrap();
+impl ChatTemplate {
+    pub fn from_gguf(gguf: &GGufModel, tokenize: &Tokenizer) -> Option<ChatTemplate> {
+        let template = gguf
+            .meta_kvs
+            .get("tokenizer.chat_template")?
+            .value_reader()
+            .read_str()
+            .unwrap()
+            .into();
 
-    Some(ChatTemplate::new(
-        template,
-        tokenize.decode(bos).into(),
-        tokenize.decode(eos).into(),
-    ))
-}
+        let bos = gguf.meta_kvs["tokenizer.ggml.bos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+        let eos = gguf.meta_kvs["tokenizer.ggml.eos_token_id"]
+            .value_reader()
+            .read::<utok>()
+            .unwrap();
+
+        Some(ChatTemplate::new(
+            template,
+            tokenize.decode(bos).into(),
+            tokenize.decode(eos).into(),
+        ))
+    }
 
-impl ChatTemplate {
     /// Create a new chat template.
     pub fn new(template: String, bos: String, eos: String) -> Self {
         static NEXT: AtomicUsize = AtomicUsize::new(0);
         let id = NEXT.fetch_add(1, Relaxed).to_string();
 
-        jinja()
+        JINJA_ENV
             .write()
             .unwrap()
             .add_template_owned(id.clone(), template)
@@ -76,7 +75,7 @@ impl ChatTemplate {
             add_generation_prompt: bool,
         }
 
-        jinja()
+        JINJA_ENV
             .read()
             .unwrap()
             .get_template(&self.id)
@@ -92,26 +91,23 @@ impl ChatTemplate {
 
 impl Drop for ChatTemplate {
     fn drop(&mut self) {
-        jinja().write().unwrap().remove_template(&self.id);
+        JINJA_ENV.write().unwrap().remove_template(&self.id);
     }
 }
 
-fn jinja() -> &'static RwLock<Environment<'static>> {
-    static ENV: LazyLock<RwLock<Environment<'_>>> = LazyLock::new(|| {
-        let mut env = Environment::empty();
-        env.set_unknown_method_callback(|_, value, method, args| {
-            use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
-            match (method, value.kind(), args) {
-                ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
-                    value.to_str().unwrap().trim().into(),
-                )),
-                _ => Err(UnknownMethod.into()),
-            }
-        });
-        RwLock::new(env)
+static JINJA_ENV: LazyLock<RwLock<Environment<'_>>> = LazyLock::new(|| {
+    let mut env = Environment::empty();
+    env.set_unknown_method_callback(|_, value, method, args| {
+        use minijinja::{value::ValueKind as ThisType, ErrorKind::UnknownMethod, Value};
+        match (method, value.kind(), args) {
+            ("strip", ThisType::String, []) => Ok(Value::from_safe_string(
+                value.to_str().unwrap().trim().into(),
+            )),
+            _ => Err(UnknownMethod.into()),
+        }
     });
-    &ENV
-}
+    RwLock::new(env)
+});
 
 #[test]
 fn test() {
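For orientation, a hedged sketch of how the refactored constructor might be used; only ChatTemplate::from_gguf, Tokenizer, and GGufModel are taken from the diff above, the wrapper function is assumed:

// Sketch, not part of the commit: the free function build_render is gone;
// construction now goes through the associated function from_gguf, which
// returns None when the model metadata has no tokenizer.chat_template.
use causal_lm::{ChatTemplate, Tokenizer};
use common::GGufModel;

fn load_template(gguf: &GGufModel, tokenizer: &Tokenizer) -> Option<ChatTemplate> {
    let template = ChatTemplate::from_gguf(gguf, tokenizer)?;
    // Templates register themselves in the shared JINJA_ENV on creation
    // and deregister in Drop, so the handle can simply be returned.
    Some(template)
}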

causal-lm/src/lib.rs

Lines changed: 39 additions & 21 deletions

@@ -1,21 +1,21 @@
 #![doc = include_str!("../README.md")]
-#![deny(warnings, missing_docs)]
+// #![deny(warnings, missing_docs)]
 
+mod chat_template;
 mod decoding;
 mod query_context;
-mod render;
-mod tokenize;
+mod tokenizer;
 
 use common::{upos, utok};
 use digit_layout::types::U32;
-use std::{path::Path, time::Duration};
+use std::{io::Write, path::Path};
 use tensor::{udim, Tensor};
 
+pub use chat_template::ChatTemplate;
 pub use decoding::DecodingMeta;
 pub use operators::random_sample::SampleArgs;
 pub use query_context::QueryContext;
-pub use render::{build_render, ChatTemplate};
-pub use tokenize::{build_tokenize, Tokenize};
+pub use tokenizer::Tokenizer;
 
 /// A model loaded from the file system.
 pub trait Model: Sized {
@@ -24,17 +24,17 @@ pub trait Model: Sized {
     /// Errors that may occur while loading the model.
     type Error;
    /// Load the model from the file system.
-    fn load(gguf: impl AsRef<Path>, meta: Self::Config) -> Result<FromGGuf<Self>, Self::Error>;
+    fn load(gguf: impl AsRef<Path>, config: Self::Config) -> Result<FromGGuf<Self>, Self::Error>;
 }
 
 /// The model, tokenizer, and rendering template loaded from a GGuf file.
 pub struct FromGGuf<M: Model> {
     /// The model.
     pub model: M,
     /// The tokenizer.
-    pub tokenize: Box<dyn Tokenize>,
+    pub tokenizer: Tokenizer,
     /// The rendering template.
-    pub render: Option<ChatTemplate>,
+    pub chat_template: Option<ChatTemplate>,
 }
 
 /// Causal language model.
@@ -119,32 +119,38 @@ pub fn pos<'a, S: 'a>(
 }
 
 /// Test a model implementation.
-pub fn test_impl<M>(meta: M::Config, prompt: &[utok])
+pub fn test_impl<M>(meta: M::Config, max_steps: usize, prompt: &str)
 where
     M: CausalLM,
     M::Error: std::fmt::Debug,
 {
-    use std::time::Instant;
+    use std::time::{Duration, Instant};
 
     let Some(gguf) = common::test_model::find() else {
         return;
     };
     println!("model: {}", gguf.display());
 
-    let t0 = Instant::now();
-    let FromGGuf { model, .. } = M::load(gguf, meta).unwrap();
-    let t1 = Instant::now();
-    println!("load {:?}", t1 - t0);
+    let time = Instant::now();
+    let FromGGuf {
+        model, tokenizer, ..
+    } = M::load(gguf, meta).unwrap();
+    println!("load {:?}", time.elapsed());
 
-    let mut cache = model.new_cache();
+    let mut prompt = tokenizer.encode(prompt);
+    print!("prompt:");
+    for t in &prompt {
+        print!(" {t}");
+    }
 
-    let mut prompt = prompt.to_vec();
+    let mut tokens = prompt.clone();
     let mut pos = 0;
 
     let mut time = Duration::ZERO;
     let mut steps = 0;
 
-    while prompt != [model.eos_token()] {
+    let mut cache = model.new_cache();
+    while prompt != [model.eos_token()] && steps <= max_steps {
         let start = Instant::now();
 
         let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());
@@ -165,21 +171,33 @@ where
             num_decode: 1,
             args: SampleArgs::ARG_MAX,
         }];
-        let tokens = CausalLM::sample(&model, args, logits);
+        let token = CausalLM::sample(&model, args, logits)[0];
 
         if steps > 0 {
             time += start.elapsed();
         }
        steps += 1;
 
-        println!("{:?}", tokens);
+        print!(" {token}");
+        std::io::stdout().flush().unwrap();
+
         pos += prompt.len() as upos;
-        prompt = tokens;
+        prompt.clear();
+        prompt.push(token);
+        tokens.push(token);
    }
 
     steps -= 1;
+    println!();
     println!(
         "steps = {steps}, average decoding time = {:?}",
         time.div_f32(steps as _)
     );
+    println!();
+    println!("---");
+    for t in tokens {
+        print!("{}", tokenizer.decode(t));
+    }
+    println!();
+    println!("---");
 }
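A backend crate would now drive the reworked harness with a step budget and a plain-text prompt rather than pre-tokenized ids; roughly as follows, where MyModel and MyConfig are placeholders for a concrete implementation:

// Sketch: MyModel / MyConfig stand in for a concrete CausalLM backend.
// test_impl locates a test model on disk, tokenizes the prompt, decodes
// greedily for at most max_steps steps, and prints timing plus the
// decoded text between "---" markers.
#[test]
fn test_infer() {
    causal_lm::test_impl::<MyModel>(MyConfig::default(), 100, "Once upon a time,");
}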

causal-lm/src/tokenize.rs

Lines changed: 0 additions & 132 deletions
This file was deleted.
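The deleted Tokenize trait object is superseded by the concrete Tokenizer re-exported from the new tokenizer module, whose source is not part of this extract. Judging only from its call sites in test_impl above, its surface is roughly the following; the return types and the dummy bodies are assumptions:

// Assumed shape, inferred from tokenizer.encode(prompt) yielding a list of
// token ids and tokenizer.decode(t) yielding something printable.
#[allow(non_camel_case_types)]
type utok = u32; // token-id type from the `common` crate

struct Tokenizer;

impl Tokenizer {
    fn encode(&self, text: &str) -> Vec<utok> {
        // dummy stand-in: one id per byte
        text.bytes().map(utok::from).collect()
    }
    fn decode(&self, token: utok) -> String {
        // dummy stand-in: id back to a character
        char::from_u32(token).map(String::from).unwrap_or_default()
    }
}

fn main() {
    let tk = Tokenizer;
    for t in tk.encode("hi") {
        print!("{}", tk.decode(t));
    }
    println!();
}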
