Skip to content

Commit 8140ee1

Browse files
committed
feat: 支持重复惩罚
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 06f9de0 commit 8140ee1

File tree

15 files changed

+356
-107
lines changed

15 files changed

+356
-107
lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llama.cu/src/batch/default.rs

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
use super::{BatchStrategy, Req, Round, SessionId, SessionStub};
2+
use crate::batch::SampleInfo;
23
use log::warn;
3-
use std::{cmp::min, collections::BTreeMap, iter::repeat_n, mem::take};
4+
use std::{collections::BTreeMap, mem::take};
45

56
pub(crate) struct DefaultStrategy<T> {
67
sess: BTreeMap<SessionId, SessionStub<T>>,
78
pre_output: BTreeMap<SessionId, usize>,
89
// 每次 prefill 的最大长度
9-
chunked_prefill_max_len: Option<usize>,
10+
chunked_prefill_max_len: usize,
1011
max_toks: usize,
1112
}
1213

@@ -15,7 +16,7 @@ impl<T> DefaultStrategy<T> {
1516
Self {
1617
sess: Default::default(),
1718
pre_output: Default::default(),
18-
chunked_prefill_max_len: chunked_prefill_len,
19+
chunked_prefill_max_len: chunked_prefill_len.unwrap_or(usize::MAX),
1920
max_toks,
2021
}
2122
}
@@ -61,32 +62,27 @@ impl<T: 'static + Clone> BatchStrategy<T> for DefaultStrategy<T> {
6162
let remain_tok_num = self.max_toks - ans.tokens.len();
6263
assert!(remain_tok_num > 0);
6364

65+
let input_idx = ans.tokens.len();
6466
if let Some(prompt) = &stub.prompt {
65-
seq = self
66-
.chunked_prefill_max_len
67-
.map_or(min(remain_tok_num, seq), |chunked_prefill_max_len| {
68-
remain_tok_num.min(seq).min(chunked_prefill_max_len)
69-
});
67+
seq = self.chunked_prefill_max_len.min(seq).min(remain_tok_num);
68+
let (prompt, tail) = prompt[prompt.len() - stub.state.seq..].split_at(seq);
7069

71-
if seq < stub.state.seq {
72-
// chunked prefill
73-
out = 0;
74-
end = pos + seq;
75-
76-
ans.tokens
77-
.extend(prompt.iter().skip(prompt.len() - stub.state.seq).take(seq));
78-
79-
//更新stub信息
80-
stub.state.seq -= seq
81-
} else {
70+
if tail.is_empty() {
8271
// 正常 prefill
8372
if seq != prompt.len() {
8473
log::debug!("{id:?} chunked prefil finished")
8574
}
86-
ans.tokens.extend(prompt[prompt.len() - seq..].to_owned());
87-
75+
ans.tokens.extend(prompt);
76+
// 更新 stub 信息
8877
stub.state.seq = 1;
8978
stub.prompt = None
79+
} else {
80+
// chunked prefill
81+
out = 0;
82+
end = pos + seq;
83+
ans.tokens.extend(prompt);
84+
// 更新 stub 信息
85+
stub.state.seq = tail.len()
9086
}
9187
} else {
9288
// decode
@@ -100,25 +96,36 @@ impl<T: 'static + Clone> BatchStrategy<T> for DefaultStrategy<T> {
10096
// 尝试填充缓存
10197
stub.session.cache.len = end;
10298
// 填充推理信息
103-
ans.sample.extend(repeat_n(stub.session.sample_args, out));
99+
ans.sample
100+
.extend((input_idx..input_idx + out).map(|input_idx| {
101+
(
102+
id,
103+
SampleInfo {
104+
args: stub.session.sample_args,
105+
input_idx,
106+
decode_len: stub.state.decode_len,
107+
},
108+
)
109+
}));
104110
ans.output.push((id, out));
105111
ans.reqs.push(Req {
106112
cache: stub.session.cache.cache.clone(),
107113
pos,
108114
seq,
109115
});
116+
if out > 0 {
117+
stub.state.decode_len += 1
118+
}
110119

111120
// 输出处理
112-
// 不会溢出 因为 out <= 1
113-
stub.state.remain_steps -= out;
114-
if stub.state.remain_steps == 0 {
121+
if stub.state.decode_len == stub.state.max_steps {
115122
// 生成结束
116123
ans.finished.push(stub.session)
117124
} else {
118125
// 回填
119126
assert!(write_back_sessions.insert(id, stub).is_none());
120127
if out != 0 {
121-
assert!(self.pre_output.insert(id, out_idx).is_none());
128+
assert!(self.pre_output.insert(id, out_idx).is_none())
122129
}
123130
}
124131
out_idx += out;

llama.cu/src/batch/mod.rs

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ pub struct Round<T> {
2828
pub overflow: Vec<Session<T>>,
2929
pub tokens: Vec<utok>,
3030
pub reqs: Vec<Req<T>>,
31-
pub sample: Vec<SampleArgs>,
31+
pub sample: Vec<(SessionId, SampleInfo)>,
3232
pub output: Vec<(SessionId, usize)>,
3333
pub fast_map: Vec<(utok, utok)>,
3434
pub finished: Vec<Session<T>>,
@@ -48,6 +48,13 @@ impl<T> Default for Round<T> {
4848
}
4949
}
5050

51+
#[derive(Clone, Copy)]
52+
pub struct SampleInfo {
53+
pub args: SampleArgs,
54+
pub input_idx: usize,
55+
pub decode_len: usize,
56+
}
57+
5158
pub struct Session<T> {
5259
pub id: SessionId,
5360
pub sample_args: SampleArgs,
@@ -64,7 +71,8 @@ pub struct Cache<T> {
6471
pub(super) struct State {
6572
pub seq: usize,
6673
pub out: usize,
67-
pub remain_steps: usize,
74+
pub decode_len: usize,
75+
pub max_steps: usize,
6876
}
6977

7078
#[derive(Clone)]

llama.cu/src/exec/engine.rs

Lines changed: 34 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
use crate::{
66
CacheParts,
77
batch::{Req, Round, SessionStub, State},
8-
exec::{group::ModelGroupConfig, upos},
8+
exec::{group::ModelGroupConfig, sample_manager::SampleManager, upos},
99
handle::Handle,
1010
op::{FastEmbedding, random_sample::KVPair},
1111
};
@@ -47,7 +47,8 @@ impl Request {
4747
state: State {
4848
seq: prompt.len(),
4949
out,
50-
remain_steps: max_steps,
50+
decode_len: 0,
51+
max_steps,
5152
},
5253
prompt: Some(prompt),
5354
}
@@ -72,6 +73,7 @@ pub struct Progress {
7273

7374
pub(crate) fn engine(
7475
llama: LLaMA<Tensor<&[u8], 2>>,
76+
eos: utok,
7577
workers: &[(c_int, Option<Arc<Progress>>)],
7678
commands: Receiver<Command>,
7779
outputs: Sender<Output>,
@@ -80,6 +82,7 @@ pub(crate) fn engine(
8082
if let &[(gpu, progress)] = &workers {
8183
return mono(
8284
llama,
85+
eos,
8386
Device::new(*gpu),
8487
progress.clone(),
8588
commands,
@@ -146,6 +149,7 @@ pub(crate) fn engine(
146149

147150
fn mono(
148151
mut llama: LLaMA<Tensor<&[u8], 2>>,
152+
eos: utok,
149153
dev: Device,
150154
progress: Option<Arc<Progress>>,
151155
commands: Receiver<Command>,
@@ -171,7 +175,7 @@ fn mono(
171175
task_box: Default::default(),
172176
chunked_prefill_len: CHUNKED_PREFILL_LEN,
173177
}
174-
.lead(llama, output_head, commands, outputs, |ctx| {
178+
.lead(llama, eos, output_head, commands, outputs, |ctx| {
175179
Handle::new(ctx)
176180
})
177181
}
@@ -200,6 +204,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
200204
fn lead(
201205
self,
202206
llama: LLaMA<Tensor<&[u8], 2>>,
207+
eos: utok,
203208
output_head: nn::OutputHead<Tensor<&[u8], 2>>,
204209
commands: Receiver<Command>,
205210
outputs: Sender<Output>,
@@ -233,6 +238,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
233238
);
234239

235240
let mut output_head = OutputHead::new(output_head, ctx);
241+
let mut sample_manager = SampleManager::new(output_head.nvoc(), eos, ctx);
236242

237243
let max_tok = max_toks;
238244
let mut fast_embd = FastEmbedding::new(max_tok, ctx);
@@ -246,7 +252,6 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
246252
let mut pos_buf = BufN::<upos>::new(len, BUF_LEVEL, ctx);
247253
let mut out_idx_buf = BufN::<utok>::new(len, BUF_LEVEL, ctx);
248254
let mut fast_embd_buf = BufN::<(utok, utok)>::new(len, BUF_LEVEL, ctx);
249-
250255
if outputs.send(Output::Ready).is_ok() {
251256
while manager.receive(&commands, &outputs).is_ok() {
252257
// 组织请求
@@ -302,33 +307,34 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
302307
barrier.wait();
303308
models.share_inputs(key, &mut handle, &stream);
304309
}
310+
let mut input = stream.malloc::<utok>(tok.len() / size_of::<utok>());
311+
stream.memcpy_d2d(&mut input, tok);
305312
// 推理
306313
let x = models.launch(key, &reqs, &mut handle, &stream);
307-
308314
// 如果没有输出,则跳过
309-
if !out_idx.is_empty() {
310-
let output = output
311-
.into_iter()
312-
.filter_map(|(id, len)| if len > 0 { Some((id, len)) } else { None })
313-
.collect::<Vec<_>>();
314-
let kv_pairs = output_head.launch(
315-
x,
316-
&out_idx_buf[..out_idx.len()],
317-
sample,
318-
&mut handle,
319-
&stream,
320-
);
321-
stream.memcpy_d2d(&mut pre_kv_pairs[..kv_pairs.len()], &kv_pairs);
322-
323-
let output = Output::Complete {
324-
output: output.into(),
325-
kv_pair: kv_pairs.sporulate(),
326-
event: stream.record().sporulate(),
327-
finished: finished.into(),
328-
};
329-
if outputs.send(output).is_err() {
330-
break;
331-
}
315+
if out_idx.is_empty() {
316+
continue;
317+
}
318+
// 计算输出头
319+
let logits =
320+
output_head.launch(x, &out_idx_buf[..out_idx.len()], &mut handle, &stream);
321+
// 采样
322+
let kv_pairs = sample_manager.sample(logits, &input, &sample, &stream);
323+
stream.free(input);
324+
stream.memcpy_d2d(&mut pre_kv_pairs[..kv_pairs.len()], &kv_pairs);
325+
// 生成并发送输出
326+
let output = output
327+
.into_iter()
328+
.filter_map(|(id, len)| if len > 0 { Some((id, len)) } else { None })
329+
.collect();
330+
let output = Output::Complete {
331+
output,
332+
kv_pair: kv_pairs.sporulate(),
333+
event: stream.record().sporulate(),
334+
finished: finished.into(),
335+
};
336+
if outputs.send(output).is_err() {
337+
break;
332338
}
333339
}
334340
}

llama.cu/src/exec/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ mod group;
44
mod kv_cache;
55
mod model;
66
mod output_head;
7+
mod sample_manager;
78
mod step;
89

910
use crate::{

0 commit comments

Comments (0)