Skip to content

Commit 5e8be04

Browse files
committed
refactor(llama.cu): 优化重复惩罚
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent f823b26 commit 5e8be04

File tree

5 files changed

+91
-61
lines changed

5 files changed

+91
-61
lines changed

llama.cu/src/exec/engine.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
225225
let gpu = Gpu::new(dev.retain_primary(), Default::default());
226226
let attn = Attn::new(&gpu);
227227
gpu.apply(|ctx| {
228-
let mut manager = EngineManager::new(chunked_prefill_len, max_toks);
229228
let mut handle = handle(ctx);
230229
let mut models = ModelGroup::new(
231230
llama,
@@ -237,6 +236,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
237236
barrier.as_deref(),
238237
);
239238

239+
let mut manager = EngineManager::new(chunked_prefill_len, max_toks);
240240
let mut output_head = OutputHead::new(output_head, ctx);
241241
let mut sample_manager = SampleManager::new(output_head.nvoc(), eos, ctx);
242242

@@ -253,7 +253,9 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
253253
let mut out_idx_buf = BufN::<utok>::new(len, BUF_LEVEL, ctx);
254254
let mut fast_embd_buf = BufN::<(utok, utok)>::new(len, BUF_LEVEL, ctx);
255255
if outputs.send(Output::Ready).is_ok() {
256-
while manager.receive(&commands, &outputs).is_ok() {
256+
while let Ok(removed) = manager.receive(&commands, &outputs) {
257+
// 处理已移除会话
258+
sample_manager.remove(removed);
257259
// 组织请求
258260
let Round {
259261
overflow,
@@ -264,11 +266,14 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
264266
fast_map,
265267
finished,
266268
} = manager.prepare();
269+
// 处理缓存溢出
270+
sample_manager.remove(overflow.iter().map(|s| s.id));
267271
if !overflow.is_empty()
268272
&& outputs.send(Output::Overflow(overflow.into())).is_err()
269273
{
270274
break;
271275
}
276+
// 如果不需要推理
272277
if tokens.is_empty() {
273278
assert!(
274279
reqs.is_empty()
@@ -279,6 +284,7 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
279284
);
280285
continue;
281286
}
287+
// 更新 host 多级缓存
282288
let out_idx = out_idx(&reqs, output.iter().map(|(_, len)| *len));
283289
events[out_idx_buf.index()].synchronize();
284290
tok_buf.save(&tokens);
@@ -322,6 +328,8 @@ impl<T: IntoIterator<Item = usize>> Worker<T> {
322328
let kv_pairs = sample_manager.sample(logits, &input, &sample, &stream);
323329
stream.free(input);
324330
stream.memcpy_d2d(&mut pre_kv_pairs[..kv_pairs.len()], &kv_pairs);
331+
// 处理推理结束
332+
sample_manager.remove(finished.iter().map(|s| s.id));
325333
// 生成并发送输出
326334
let output = output
327335
.into_iter()

llama.cu/src/exec/engine_manager.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
use super::{Command, Output};
22
use crate::{
3-
CacheParts,
3+
CacheParts, SessionId,
44
batch::{BatchStrategy, DefaultStrategy, Round, SessionStub},
55
};
6-
use std::sync::mpsc::{Receiver, Sender, TryRecvError};
6+
use std::{
7+
collections::BTreeSet,
8+
sync::mpsc::{Receiver, Sender, TryRecvError},
9+
};
710

811
pub(super) struct EngineManager(DefaultStrategy<CacheParts>);
912

@@ -26,12 +29,13 @@ impl EngineManager {
2629
&mut self,
2730
commands: &Receiver<Command>,
2831
outputs: &Sender<Output>,
29-
) -> Result<(), E> {
32+
) -> Result<BTreeSet<SessionId>, E> {
33+
let mut removed = BTreeSet::new();
3034
loop {
3135
// 总是尝试进行非阻塞接收
3236
loop {
3337
match commands.try_recv() {
34-
Ok(cmd) => self.apply(cmd, outputs)?,
38+
Ok(cmd) => self.apply(cmd, outputs, &mut removed)?,
3539
Err(TryRecvError::Disconnected) => return Err(E::ReceiveError),
3640
Err(TryRecvError::Empty) => break,
3741
}
@@ -40,14 +44,15 @@ impl EngineManager {
4044
if self.0.is_empty() {
4145
// 也没有待处理的任务,阻塞等待
4246
match commands.recv() {
43-
Ok(cmd) => self.apply(cmd, outputs)?,
44-
Err(_) => break Err(E::ReceiveError),
47+
Ok(cmd) => self.apply(cmd, outputs, &mut removed)?,
48+
Err(_) => return Err(E::ReceiveError),
4549
}
4650
} else {
4751
// 有待处理的任务,退出循环
48-
break Ok(());
52+
break;
4953
}
5054
}
55+
Ok(removed)
5156
}
5257

5358
/// 准备推理
@@ -59,22 +64,26 @@ impl EngineManager {
5964
self.0.take_stubs()
6065
}
6166

62-
fn apply(&mut self, cmd: Command, outputs: &Sender<Output>) -> Result<(), CommandError> {
67+
fn apply(
68+
&mut self,
69+
cmd: Command,
70+
outputs: &Sender<Output>,
71+
removed: &mut BTreeSet<SessionId>,
72+
) -> Result<(), CommandError> {
6373
match cmd {
6474
Command::ShutDown => Err(CommandError::ShutDown),
6575
Command::Insert(req) => {
6676
self.0.insert(req.into_stub());
6777
Ok(())
6878
}
6979
Command::Remove(id) => {
70-
if self
71-
.0
72-
.remove(&id)
73-
.is_none_or(|stub| outputs.send(Output::Removed(stub.session)).is_ok())
74-
{
75-
Ok(())
80+
if let Some(stub) = self.0.remove(&id) {
81+
removed.insert(stub.session.id);
82+
outputs
83+
.send(Output::Removed(stub.session))
84+
.map_err(|_| CommandError::SendError)
7685
} else {
77-
Err(CommandError::SendError)
86+
Ok(())
7887
}
7988
}
8089
}

llama.cu/src/exec/sample_manager.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ impl<'ctx> SampleManager<'ctx> {
3939
let logits = logits_.as_mut().map(|mem| mem.as_ptr().cast());
4040
dims!([out_len, _nvoc] = logits);
4141

42+
let kv_pair_template = Tensor::from_dim_slice(KV_PAIR, []);
4243
let kv_pair = stream.malloc::<KVPair>(out_len);
4344
for (i, (id, info)) in config.into_iter().enumerate() {
4445
let logits = logits.clone().transform(|layout| layout.index(0, i));
@@ -48,7 +49,7 @@ impl<'ctx> SampleManager<'ctx> {
4849
decode_len,
4950
} = info;
5051

51-
let scale = state
52+
let state = state
5253
.entry(*id)
5354
.or_insert_with(|| modifier.new_state(stream));
5455
let tok = if *decode_len == 0 {
@@ -60,23 +61,30 @@ impl<'ctx> SampleManager<'ctx> {
6061
unsafe {
6162
modifier.next(
6263
&logits,
63-
scale.as_mut_ptr(),
64+
state.as_mut_ptr(),
6465
tok,
6566
args.temperature,
6667
args.repetition_penalty,
6768
stream,
6869
)
69-
};
70+
}
7071

71-
let kv_pair = Tensor::from_dim_slice(KV_PAIR, [])
72+
let kv_pair = kv_pair_template
73+
.as_ref()
7274
.map(|_| kv_pair[i * size_of::<KVPair>()..].as_ptr().cast());
7375
if args.is_argmax() {
74-
sample.argmax(kv_pair.clone(), logits, stream)
76+
sample.argmax(kv_pair, logits, stream)
7577
} else {
76-
sample.sample(kv_pair.clone(), logits, *args, rand::random(), stream)
78+
sample.sample(kv_pair, logits, *args, rand::random(), stream)
7779
}
7880
}
7981
stream.free(logits_.take());
8082
kv_pair
8183
}
84+
85+
pub fn remove(&mut self, id: impl IntoIterator<Item = SessionId>) {
86+
for id in id {
87+
self.state.remove(&id);
88+
}
89+
}
8290
}

llama.cu/src/op/random_sample/modifier.rs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
//! <https://zhuanlan.zhihu.com/p/667025336>
22
33
use crate::utils::offset_ptr;
4-
use ggus::ggml_quants::f16;
54
use log::warn;
65
use nn::Tensor;
76
use operators::cuda::{CurrentCtx, DevByte, DevMem, Module, Ptx, Stream, VirByte, params};
@@ -26,13 +25,13 @@ impl<'ctx> LogitsModifier<'ctx> {
2625

2726
impl LogitsModifier<'_> {
2827
pub fn new_state<'ctx>(&self, stream: &Stream<'ctx>) -> DevMem<'ctx> {
29-
stream.malloc::<f16>(self.n)
28+
stream.malloc::<u32>(self.n)
3029
}
3130

3231
pub unsafe fn next<const N: usize>(
3332
&self,
3433
logits: &Tensor<*const VirByte, N>,
35-
scale: *mut DevByte,
34+
records: *mut DevByte,
3635
tok: *const DevByte,
3736
mut temperature: f32,
3837
penalty: f32,
@@ -47,12 +46,12 @@ impl LogitsModifier<'_> {
4746
(n.div_ceil(256), 256, 0),
4847
&params![
4948
offset_ptr(logits),
50-
scale,
49+
records,
5150
n,
51+
self.eos,
5252
temperature,
53-
penalty.recip(),
54-
tok,
55-
self.eos
53+
penalty,
54+
tok
5655
]
5756
.to_ptrs(),
5857
);
@@ -61,19 +60,21 @@ impl LogitsModifier<'_> {
6160
fn compile<'ctx>(ctx: &'ctx CurrentCtx) -> Module<'ctx> {
6261
const CODE: &str = include_str!("modify.cuh");
6362
let code = format!(
64-
r#"
65-
{CODE}
63+
r#"{CODE}
6664
6765
extern "C" __global__ void next(
68-
half *logits,
69-
half *scale,
70-
unsigned int const n,
71-
float const temperature,
72-
float const penalty,
73-
unsigned int const *tok,
74-
unsigned int const eos
66+
// 采样分布和状态
67+
half *logits, // 概率分布
68+
unsigned int *records, // 每个 token 的出现次数
69+
// 词表信息
70+
unsigned int const n, // 词表长度
71+
unsigned int const eos, // 结束符
72+
// 采样参数
73+
float const temperature, // 温度
74+
float const penalty, // 重复惩罚
75+
unsigned int const *tok // 上一次采样结果
7576
) {{
76-
next_kernel(logits, scale, n, temperature, penalty, tok, eos);
77+
next_kernel(logits, records, n, eos, temperature, penalty, tok);
7778
}}"#
7879
);
7980
let (ptx, log) = Ptx::compile(code, ctx.dev().compute_capability());
Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,36 @@
11
template <typename T>
22
__global__ void next_kernel(
3-
T *logits,
4-
T *scale_,
5-
unsigned int const n,
6-
float const temperature,
7-
float const penalty,
8-
unsigned int const *tok,
9-
unsigned int const eos) {
3+
// 采样分布和状态
4+
T *logits, // 概率分布
5+
unsigned int *records, // 每个 token 的出现次数
6+
// 词表信息
7+
unsigned int const n, // 词表长度
8+
unsigned int const eos, // 结束符
9+
// 采样参数
10+
float const temperature, // 温度
11+
float const penalty, // 重复惩罚
12+
unsigned int const *tok // 上一次采样结果
13+
) {
1014
unsigned int const i = blockIdx.x * blockDim.x + threadIdx.x;
1115
if (i >= n) {
1216
return;
1317
}
14-
float scale;
18+
// 更新出现次数
1519
if (!tok) {
16-
// 初始化惩罚权重
17-
scale = i == eos ? 0 : 1;
18-
scale_[i] = 1;
19-
} else {
20-
// 更新惩罚权重
21-
scale = (float)scale_[i];
22-
if (i == *tok) {
23-
scale *= penalty;
24-
scale_[i] = scale;
25-
}
20+
records[i] = 0;
21+
} else if (i == *tok) {
22+
++records[i];
2623
}
27-
if (((float)logits[i]) > .0) {
28-
logits[i] *= (T)(temperature * scale);
24+
// 调整分布
25+
if (!tok && i == eos) {
26+
// 第一轮解码绝不产生 eos
27+
((unsigned int *)logits)[i] = 0xFF800000; // float -∞
2928
} else {
30-
logits[i] /= (T)(temperature * scale);
29+
T scale = temperature * powf(penalty, records[i]);
30+
if (((float)logits[i]) > .0) {
31+
logits[i] /= scale;
32+
} else {
33+
logits[i] *= scale;
34+
}
3135
}
3236
}

0 commit comments

Comments
 (0)