Commit b11aa15

Merge pull request #37 from pwhMass/fix_chunked_prefill
fix: Fix chunked prefill
2 parents 88e1323 + a9b70da commit b11aa15

File tree: 3 files changed, +85 -81 lines changed


llama.cu/src/exec/engine.rs

Lines changed: 12 additions & 15 deletions
@@ -66,9 +66,14 @@ impl Request {
     }
 }
 
+// Used to test chunked prefill and dynamic graph building
+// const NTOKS: [usize; 2] = [1, 4];
+// const CHUNKED_PREFILL_LEN: Option<usize> = Some(2);
+// const MAX_TOKS: usize = 8;
+
+// TODO: these parameters may need tuning; they are currently set from experience
 const NTOKS: [usize; 7] = [1, 8, 32, 64, 128, 256, 1024];
-const CHUNKED_PREFILL_LEN: Option<usize> = Some(32);
-// TODO: where should this constant live?
+const CHUNKED_PREFILL_LEN: Option<usize> = Some(256);
 const MAX_TOKS: usize = 1024;
 
 pub(crate) fn engine(
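
The commented-out block keeps a tiny configuration (buckets [1, 4], chunk length 2, budget 8) for exercising chunked prefill and dynamic graph building; the active constants are the empirically chosen defaults. Assuming NTOKS lists the candidate batch sizes that graphs are built for (an assumption, not stated in the diff), picking a size for a step amounts to taking the smallest entry that fits, as in this minimal sketch (pick_bucket is a hypothetical helper):

const NTOKS: [usize; 7] = [1, 8, 32, 64, 128, 256, 1024];

// Hypothetical helper: smallest candidate batch size that can hold `n` tokens.
fn pick_bucket(n: usize) -> Option<usize> {
    NTOKS.iter().copied().find(|&cap| cap >= n)
}

fn main() {
    assert_eq!(pick_bucket(1), Some(1));
    assert_eq!(pick_bucket(200), Some(256)); // rounded up to the next bucket
    assert_eq!(pick_bucket(1024), Some(1024)); // fills the largest bucket exactly
    assert_eq!(pick_bucket(2000), None); // larger than every bucket
}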
@@ -203,7 +208,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
         let gpu = Gpu::new(dev.retain_primary(), Default::default());
         let attn = Attn::new(&gpu);
         gpu.apply(|ctx| {
-            let mut manager = EngineManager::new(chunked_prefill_len);
+            let mut manager = EngineManager::new(chunked_prefill_len, max_toks);
             let mut handle = handle(ctx);
             let mut models =
                 ModelGroup::new(llama, dist, config, attn, &mut handle, barrier.as_deref());
@@ -283,18 +288,10 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
 
             // Skip if there is no output
             if !out_idx.is_empty() {
-                let (output, sample): (Vec<_>, Vec<_>) = output
-                    .iter()
-                    .zip(sample.iter())
-                    .filter_map(|((id, len), sample_arg)| {
-                        if *len > 0 {
-                            Some(((*id, *len), sample_arg))
-                        } else {
-                            None
-                        }
-                    })
-                    .unzip();
-
+                let output = output
+                    .into_iter()
+                    .filter_map(|(id, len)| if len > 0 { Some((id, len)) } else { None })
+                    .collect::<Vec<_>>();
                 let kv_pairs = output_head.launch(
                     x,
                     &out_idx_buf[..out_idx.len()],
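
The old code zipped `output` with `sample` and filtered both through an `unzip`; the new code filters `output` alone, keeping only entries whose length is positive (a chunked-prefill step records an output length of 0). A self-contained sketch of the new pattern, with a stand-in session id type:

// Stand-in for the engine's session id type.
type SessionId = u64;

fn keep_producing(output: Vec<(SessionId, usize)>) -> Vec<(SessionId, usize)> {
    output
        .into_iter()
        .filter_map(|(id, len)| if len > 0 { Some((id, len)) } else { None })
        .collect()
}

fn main() {
    // Session 2 is mid chunked prefill, so it reports no output this step.
    let kept = keep_producing(vec![(1, 1), (2, 0), (3, 1)]);
    let expected: Vec<(SessionId, usize)> = vec![(1, 1), (3, 1)];
    assert_eq!(kept, expected);
}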

llama.cu/src/exec/engine_manager.rs

Lines changed: 65 additions & 60 deletions
@@ -2,6 +2,7 @@
 use crate::{exec::KVCache, op::random_sample::SampleArgs};
 use log::warn;
 use std::{
+    cmp::min,
     collections::BTreeMap,
     iter::repeat_n,
     mem::take,
@@ -12,12 +13,12 @@ use std::{
 };
 use tokeneer::utok;
 
-#[derive(Default)]
 pub(super) struct EngineManager {
     sess: BTreeMap<SessionId, SessionStub>,
     pre_output: BTreeMap<SessionId, usize>,
     // Maximum length of each prefill chunk
-    chunked_prefill_len: Option<usize>,
+    chunked_prefill_max_len: Option<usize>,
+    max_toks: usize,
 }
 
 #[derive(Default)]
@@ -41,10 +42,12 @@ pub enum CommandError {
 type E = CommandError;
 
 impl EngineManager {
-    pub fn new(chunked_prefill_len: Option<usize>) -> Self {
+    pub fn new(chunked_prefill_len: Option<usize>, max_toks: usize) -> Self {
         Self {
-            chunked_prefill_len,
-            ..Default::default()
+            sess: Default::default(),
+            pre_output: Default::default(),
+            chunked_prefill_max_len: chunked_prefill_len,
+            max_toks,
         }
     }
     /// Receive commands
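
With the extra `max_toks` field the `#[derive(Default)]` on the struct is gone, and the constructor now initializes every field explicitly. A sketch of the two call shapes, using values that appear in engine.rs (illustrative only; the type is crate-internal):

// Test configuration: chunk prefills to at most 2 tokens, cap each step at 8 tokens.
let manager = EngineManager::new(Some(2), 8);
// Chunked prefill effectively disabled: only the per-step cap of 1024 tokens applies.
let manager_unchunked = EngineManager::new(None, 1024);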
@@ -82,53 +85,62 @@ impl EngineManager {
         let mut out_idx = 0;
 
         let pre_output = take(&mut self.pre_output);
-        for (id, mut stub) in take(&mut self.sess) {
+
+        let mut write_back_sessions = BTreeMap::<SessionId, SessionStub>::new();
+
+        while let Some((id, mut stub)) = self.sess.pop_first() {
             let max = stub.session.cache.len;
             let pos = stub.session.cache.pos;
-            let seq = stub.state.seq;
-            let out = stub.state.out;
-            let end = pos + seq;
-            assert_eq!(out, 1, "TODO: ???");
+            let mut seq = stub.state.seq;
+            let mut out = stub.state.out;
+            let mut end = pos + seq;
+            assert_eq!(out, 1, "TODO: speculative sampling");
             // Check whether the cache would overflow
             if end > max {
                 warn!("cache overflow {end} > {max}");
                 // Cache overflow: stop inference for this session
                 ans.overflow.push(stub.session);
                 continue;
             }
-            // chunked prefill
-            // Use state.seq to compute the remaining length that still needs prefill
+
+            // Used to bound the total number of tokens per step
+            let remain_tok_num = self.max_toks - ans.tokens.len();
+            assert!(remain_tok_num > 0);
+
             if let Some(prompt) = &stub.prompt {
-                if let Some(chunked_prefill_len) = self.chunked_prefill_len {
-                    if stub.state.seq > chunked_prefill_len {
-                        // Recompute seq and end from chunked_prefill_len
-                        let seq = chunked_prefill_len;
-                        let end = pos + seq;
-                        ans.sample.push(stub.session.sample_args);
-                        ans.output.push((id, 0));
-                        ans.reqs.push(Req {
-                            kv_cache: stub.session.cache.parts.clone(),
-                            pos,
-                            seq,
-                        });
-                        ans.tokens.extend(
-                            prompt
-                                .iter()
-                                .skip(prompt.len() - stub.state.seq)
-                                .take(chunked_prefill_len),
-                        );
-
-                        // Update the stub state
-                        stub.session.cache.pos = end;
-                        stub.state.seq -= chunked_prefill_len;
-
-                        // Write back
-                        assert!(self.sess.insert(id, stub).is_none());
-
-                        // End early
-                        continue;
+                seq = self
+                    .chunked_prefill_max_len
+                    .map_or(min(remain_tok_num, seq), |chunked_prefill_max_len| {
+                        remain_tok_num.min(seq).min(chunked_prefill_max_len)
+                    });
+
+                if seq < stub.state.seq {
+                    // chunked prefill
+                    out = 0;
+                    end = pos + seq;
+
+                    ans.tokens
+                        .extend(prompt.iter().skip(prompt.len() - stub.state.seq).take(seq));
+
+                    // Update the stub state
+                    stub.state.seq -= seq;
+                } else {
+                    // normal prefill
+                    if seq != prompt.len() {
+                        log::debug!("{:?} chunked prefill finished", id);
                     }
+                    ans.tokens.extend(prompt[prompt.len() - seq..].to_owned());
+
+                    stub.state.seq = 1;
+                    stub.prompt = None;
                 }
+            } else {
+                // decode
+                assert_eq!(seq, 1);
+                // fast embd
+                ans.fast_map
+                    .push((pre_output[&id] as _, ans.tokens.len() as _));
+                ans.tokens.push(0)
             }
 
             // Try to fill the cache
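
The chunk size for a prefill step is now the minimum of three bounds: the tokens still missing from the prompt (`seq`), the remaining per-step budget (`remain_tok_num`), and, if configured, `chunked_prefill_max_len`. A self-contained sketch of that `map_or`/`min` computation, exercised with the production values Some(256) and a 1024-token budget (the free function is only for illustration):

use std::cmp::min;

// Mirrors the `seq` computation in the hunk above.
fn chunk_len(remain_tok_num: usize, seq: usize, chunked_prefill_max_len: Option<usize>) -> usize {
    chunked_prefill_max_len.map_or(min(remain_tok_num, seq), |max_len| {
        remain_tok_num.min(seq).min(max_len)
    })
}

fn main() {
    // Fresh step, 1000-token prompt, chunk cap 256: prefill 256 tokens now.
    assert_eq!(chunk_len(1024, 1000, Some(256)), 256);
    // Only 100 tokens of budget left in this step: the budget wins.
    assert_eq!(chunk_len(100, 1000, Some(256)), 100);
    // Short prompt: the whole prefill fits in one chunk.
    assert_eq!(chunk_len(1024, 30, Some(256)), 30);
    // No chunk cap configured: only the step budget limits the prefill.
    assert_eq!(chunk_len(1024, 1000, None), 1000);
}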
@@ -141,35 +153,28 @@ impl EngineManager {
                 pos,
                 seq,
             });
-            if let Some(prompt) = stub.prompt.take() {
-                // prefill
-                if seq != prompt.len() {
-                    log::debug!("{:?} chunked prefill finished", id);
-                }
-                ans.tokens.extend(prompt[prompt.len() - seq..].to_owned());
-
-                stub.state.seq = 1
-            } else {
-                // decode
-                assert_eq!(seq, 1);
-                // fast embd
-                ans.fast_map
-                    .push((pre_output[&id] as _, ans.tokens.len() as _));
-                ans.tokens.push(0)
-            }
 
             // Output handling
-            stub.state.remain_steps -= 1;
+            // Cannot underflow because out <= 1
+            stub.state.remain_steps -= out;
             if stub.state.remain_steps == 0 {
                 // Generation finished
                 ans.finished.push(stub.session)
             } else {
                 // Write back
-                assert!(self.sess.insert(id, stub).is_none());
-                assert!(self.pre_output.insert(id, out_idx).is_none());
+                assert!(write_back_sessions.insert(id, stub).is_none());
+                if out != 0 {
+                    assert!(self.pre_output.insert(id, out_idx).is_none());
+                }
+            }
+            out_idx += out;
+
+            // Exit the loop once the total token budget is used up
+            if self.max_toks == ans.tokens.len() {
+                break;
             }
-            out_idx += out
         }
+        self.sess.append(&mut write_back_sessions);
         ans
     }
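
Sessions are now drained with `pop_first`, parked in `write_back_sessions` when they still have steps left, and spliced back with `BTreeMap::append` after the loop, so the scheduler can `break` as soon as the step's token budget is full without losing the sessions it never reached. A self-contained sketch of that pop/park/append pattern with stand-in types and data:

use std::collections::BTreeMap;

fn main() {
    let max_toks = 8usize;
    // Session id -> tokens this session wants to contribute this step (stand-in data).
    let mut sess: BTreeMap<u32, usize> = BTreeMap::from([(1, 3), (2, 3), (3, 3), (4, 3)]);
    let mut write_back_sessions = BTreeMap::new();
    let mut batched = 0usize;

    while let Some((id, want)) = sess.pop_first() {
        // Never exceed the per-step budget.
        let take = want.min(max_toks - batched);
        batched += take;
        // Park the session so it is considered again next step.
        write_back_sessions.insert(id, want - take);
        if batched == max_toks {
            // Budget exhausted: session 4 is never popped and simply stays in `sess`.
            break;
        }
    }
    // Splice the parked sessions back next to the ones that were never reached.
    sess.append(&mut write_back_sessions);
    assert_eq!(batched, 8);
    assert_eq!(sess.len(), 4);
}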

xtask/src/service/mod.rs

Lines changed: 8 additions & 6 deletions
@@ -131,13 +131,15 @@ async fn start_infer_service(
             } else {
                 &[]
             };
-            if think.len() < tokens.len() {
-                session_info.think = false;
-                tokens = &tokens[think.len() + 1..]
-            } else {
-                tokens = &[]
-            }
 
+            if session_info.think {
+                if think.len() < tokens.len() {
+                    session_info.think = false;
+                    tokens = &tokens[think.len() + 1..]
+                } else {
+                    tokens = &[]
+                }
+            }
             let think = service_manager_for_recv
                 .terminal
                 .decode(think, &mut session_info.buf);
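
The slicing that strips the think section is now guarded by `session_info.think`, so once the think phase has ended the freshly decoded tokens pass through instead of being re-sliced or blanked on every later step. A minimal sketch of the guarded logic on plain token slices; the surrounding decode and marker handling is assumed:

// Stand-in for the slice handling in start_infer_service: while the session is
// still in its think phase, the think section (and the one token after it) is dropped.
fn strip_think<'a>(thinking: &mut bool, think: &[u32], tokens: &'a [u32]) -> &'a [u32] {
    if *thinking {
        if think.len() < tokens.len() {
            // The think section ends inside this batch of tokens.
            *thinking = false;
            &tokens[think.len() + 1..]
        } else {
            // Everything decoded so far still belongs to the think section.
            &[]
        }
    } else {
        tokens
    }
}

fn main() {
    let mut thinking = true;
    // Think section covers 2 tokens; the token at index think.len() is skipped too.
    assert_eq!(strip_think(&mut thinking, &[7, 8], &[7, 8, 9, 10, 11]), &[10, 11]);
    assert!(!thinking);
    // After the think phase, tokens pass through untouched.
    assert_eq!(strip_think(&mut thinking, &[], &[12, 13]), &[12, 13]);
}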
