Skip to content

Commit 57df781

Browse files
committed
refactor(llama.cu): 剥离分批策略
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 06c78ad commit 57df781

File tree

12 files changed

+292
-238
lines changed

12 files changed

+292
-238
lines changed

llama.cu/src/batch/default.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
use super::{BatchStrategy, Req, Round, SessionId, SessionStub};
2+
use log::warn;
3+
use std::{cmp::min, collections::BTreeMap, iter::repeat_n, mem::take};
4+
5+
pub(crate) struct DefaultStrategy<T> {
6+
sess: BTreeMap<SessionId, SessionStub<T>>,
7+
pre_output: BTreeMap<SessionId, usize>,
8+
// 每次prefill的最大长度
9+
chunked_prefill_max_len: Option<usize>,
10+
max_toks: usize,
11+
}
12+
13+
impl<T> DefaultStrategy<T> {
14+
pub fn new(chunked_prefill_len: Option<usize>, max_toks: usize) -> Self {
15+
Self {
16+
sess: Default::default(),
17+
pre_output: Default::default(),
18+
chunked_prefill_max_len: chunked_prefill_len,
19+
max_toks,
20+
}
21+
}
22+
}
23+
24+
impl<T: 'static + Clone> BatchStrategy<T> for DefaultStrategy<T> {
25+
fn is_empty(&self) -> bool {
26+
self.sess.is_empty()
27+
}
28+
29+
fn insert(&mut self, stub: SessionStub<T>) {
30+
assert!(self.sess.insert(stub.session.id, stub).is_none())
31+
}
32+
33+
fn remove(&mut self, id: &SessionId) -> Option<SessionStub<T>> {
34+
self.sess.remove(id)
35+
}
36+
37+
fn prepare(&mut self) -> Round<T> {
38+
let mut ans = Round::default();
39+
let mut out_idx = 0;
40+
41+
let pre_output = take(&mut self.pre_output);
42+
43+
let mut write_back_sessions = BTreeMap::new();
44+
45+
while let Some((id, mut stub)) = self.sess.pop_first() {
46+
let max = stub.session.cache.capacity;
47+
let pos = stub.session.cache.len;
48+
let mut seq = stub.state.seq;
49+
let mut out = stub.state.out;
50+
let mut end = pos + seq;
51+
assert_eq!(out, 1, "TODO: 投机采样");
52+
//验证缓存是否溢出
53+
if end > max {
54+
warn!("cache overflow {end} > {max}");
55+
// 缓存溢出,不再推理
56+
ans.overflow.push(stub.session);
57+
continue;
58+
}
59+
60+
// 用于限制每次tokens总数
61+
let remain_tok_num = self.max_toks - ans.tokens.len();
62+
assert!(remain_tok_num > 0);
63+
64+
if let Some(prompt) = &stub.prompt {
65+
seq = self
66+
.chunked_prefill_max_len
67+
.map_or(min(remain_tok_num, seq), |chunked_prefill_max_len| {
68+
remain_tok_num.min(seq).min(chunked_prefill_max_len)
69+
});
70+
71+
if seq < stub.state.seq {
72+
// chunked prefill
73+
out = 0;
74+
end = pos + seq;
75+
76+
ans.tokens
77+
.extend(prompt.iter().skip(prompt.len() - stub.state.seq).take(seq));
78+
79+
//更新stub信息
80+
stub.state.seq -= seq;
81+
} else {
82+
// 正常prefill
83+
if seq != prompt.len() {
84+
log::debug!("{:?} chunked prefil finished", id);
85+
}
86+
ans.tokens.extend(prompt[prompt.len() - seq..].to_owned());
87+
88+
stub.state.seq = 1;
89+
stub.prompt = None;
90+
}
91+
} else {
92+
// decode
93+
assert_eq!(seq, 1);
94+
// fast embd
95+
ans.fast_map
96+
.push((pre_output[&id] as _, ans.tokens.len() as _));
97+
ans.tokens.push(0)
98+
}
99+
100+
// 尝试填充缓存
101+
stub.session.cache.len = end;
102+
// 填充推理信息
103+
ans.sample.extend(repeat_n(stub.session.sample_args, out));
104+
ans.output.push((id, out));
105+
ans.reqs.push(Req {
106+
cache: stub.session.cache.cache.clone(),
107+
pos,
108+
seq,
109+
});
110+
111+
//输出处理
112+
//不会溢出 因为 out <= 1
113+
stub.state.remain_steps -= out;
114+
if stub.state.remain_steps == 0 {
115+
// 生成结束
116+
ans.finished.push(stub.session)
117+
} else {
118+
// 回填
119+
assert!(write_back_sessions.insert(id, stub).is_none());
120+
if out != 0 {
121+
assert!(self.pre_output.insert(id, out_idx).is_none());
122+
}
123+
}
124+
out_idx += out;
125+
126+
// 如果剩余tokens总数等于0,则退出循环
127+
if self.max_toks == ans.tokens.len() {
128+
break;
129+
}
130+
}
131+
self.sess.append(&mut write_back_sessions);
132+
ans
133+
}
134+
135+
fn take_stubs(&mut self) -> Vec<SessionStub<T>> {
136+
take(&mut self.sess).into_values().collect()
137+
}
138+
}

llama.cu/src/batch/mod.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
mod default;
2+
3+
use crate::SampleArgs;
4+
use tokeneer::utok;
5+
6+
pub(crate) use default::DefaultStrategy;
7+
8+
pub trait BatchStrategy<T: 'static> {
9+
fn is_empty(&self) -> bool;
10+
fn insert(&mut self, stub: SessionStub<T>);
11+
fn remove(&mut self, id: &SessionId) -> Option<SessionStub<T>>;
12+
fn prepare(&mut self) -> Round<T>;
13+
fn take_stubs(&mut self) -> Vec<SessionStub<T>>;
14+
}
15+
16+
// 目前在有prompt的情况下,state.seq 的长度代表prompt还有多少未prefill,也就是 `prompt[prompt.len() - state.seq..]` 代表未prefill的prompt
17+
pub(super) struct SessionStub<T> {
18+
pub session: Session<T>,
19+
pub state: State,
20+
pub prompt: Option<Box<[utok]>>,
21+
}
22+
23+
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
24+
#[repr(transparent)]
25+
pub struct SessionId(pub usize);
26+
27+
pub struct Round<T> {
28+
pub overflow: Vec<Session<T>>,
29+
pub tokens: Vec<utok>,
30+
pub reqs: Vec<Req<T>>,
31+
pub sample: Vec<SampleArgs>,
32+
pub output: Vec<(SessionId, usize)>,
33+
pub fast_map: Vec<(utok, utok)>,
34+
pub finished: Vec<Session<T>>,
35+
}
36+
37+
impl<T> Default for Round<T> {
38+
fn default() -> Self {
39+
Self {
40+
overflow: Default::default(),
41+
tokens: Default::default(),
42+
reqs: Default::default(),
43+
sample: Default::default(),
44+
output: Default::default(),
45+
fast_map: Default::default(),
46+
finished: Default::default(),
47+
}
48+
}
49+
}
50+
51+
pub struct Session<T> {
52+
pub id: SessionId,
53+
pub sample_args: SampleArgs,
54+
pub cache: Cache<T>,
55+
}
56+
57+
pub struct Cache<T> {
58+
pub cache: T,
59+
pub capacity: usize,
60+
pub len: usize,
61+
}
62+
63+
#[derive(Clone, Copy)]
64+
pub(super) struct State {
65+
pub seq: usize,
66+
pub out: usize,
67+
pub remain_steps: usize,
68+
}
69+
70+
#[derive(Clone)]
71+
pub(crate) struct Req<Cache> {
72+
pub cache: Cache,
73+
pub pos: usize,
74+
pub seq: usize,
75+
}

llama.cu/src/exec/engine.rs

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
use super::{
2-
Command, Output, Request, Session,
3-
engine_manager::{EngineManager, Round},
4-
group::{ModelGroup, Req},
5-
kv_cache::KVCache,
2+
Command, Output, Request, engine_manager::EngineManager, group::ModelGroup,
63
output_head::OutputHead,
74
};
85
use crate::{
6+
CacheParts,
7+
batch::{Req, Round, SessionStub, State},
98
exec::{group::ModelGroupConfig, upos},
109
handle::Handle,
1110
op::{FastEmbedding, random_sample::KVPair},
@@ -23,7 +22,7 @@ use std::{
2322
num::NonZeroUsize,
2423
ops::Deref,
2524
sync::{
26-
Arc, Barrier, Mutex, OnceLock, RwLock,
25+
Arc, Barrier, OnceLock, RwLock,
2726
atomic::AtomicUsize,
2827
mpsc::{Receiver, Sender},
2928
},
@@ -33,22 +32,10 @@ use tokeneer::utok;
3332
#[cfg(nccl)]
3433
use operators::nccl::{Communicator, CommunicatorGroup};
3534

36-
// 目前在有prompt的情况下,state.seq 的长度代表prompt还有多少未prefill,也就是 `prompt[prompt.len() - state.seq..]` 代表未prefill的prompt
37-
pub(super) struct SessionStub {
38-
pub session: Session,
39-
pub state: State,
40-
pub prompt: Option<Box<[utok]>>,
41-
}
42-
43-
#[derive(Clone, Copy)]
44-
pub(super) struct State {
45-
pub seq: usize,
46-
pub out: usize,
47-
pub remain_steps: usize,
48-
}
35+
type Stub = SessionStub<CacheParts>;
4936

5037
impl Request {
51-
pub(super) fn into_stub(self) -> SessionStub {
38+
pub(super) fn into_stub(self) -> Stub {
5239
let Request {
5340
session,
5441
prompt,
@@ -206,7 +193,7 @@ type TaskBox = Arc<RwLock<Option<Task>>>;
206193
#[cfg_attr(not(nccl), allow(dead_code))]
207194
struct Task {
208195
key: NonZeroUsize,
209-
reqs: Vec<Req<Arc<[Mutex<KVCache>]>>>,
196+
reqs: Vec<Req<CacheParts>>,
210197
}
211198

212199
impl<T: IntoIterator<Item = usize>> Worker<T> {

0 commit comments

Comments
 (0)