
Commit 2896798

style(llama): tidy up and optimize the MoE model structure

Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent e14c588 commit 2896798

3 files changed (+79 -72 lines)

gguf/src/lib.rs

Lines changed: 22 additions & 0 deletions

@@ -130,3 +130,25 @@ impl GGufMetaMap for GGufModel<'_> {
         self.meta_kvs.get(key).map(|kv| (kv.ty(), kv.value_bytes()))
     }
 }
+
+mod macros {
+    #[macro_export]
+    macro_rules! meta {
+        ($gguf:expr => $key:ident) => {
+            $gguf.$key().unwrap()
+        };
+        ($gguf:expr => $key:ident; $default:expr) => {
+            match $gguf.$key() {
+                Ok(val) => val,
+                Err(gguf::GGufMetaError::NotExist) => $default,
+                Err(e) => panic!("failed to read meta: {e:?}"),
+            }
+        };
+    }
+    #[macro_export]
+    macro_rules! tensor {
+        ($gguf:expr => $name:expr) => {
+            &$gguf.tensors[&*$name]
+        };
+    }
+}
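For reference, the exported macro's two arms dispatch on the trailing `; $default` token. Below is a minimal sketch of that pattern, assuming a hypothetical `MockModel` with `Result`-returning accessors; `MockModel` and `MetaError` are stand-ins invented for this example (the real accessors live on `GGufModel` and return `gguf::GGufMetaError`):

    // Sketch of the `meta!` dispatch pattern with invented stand-in types.
    #[derive(Debug)]
    #[allow(dead_code)]
    enum MetaError {
        NotExist,
        TypeMismatch,
    }

    struct MockModel;

    impl MockModel {
        fn llm_block_count(&self) -> Result<usize, MetaError> {
            Ok(32)
        }
        fn llm_expert_count(&self) -> Result<usize, MetaError> {
            Err(MetaError::NotExist) // dense checkpoint: the key is absent
        }
    }

    macro_rules! meta {
        // required key: panic if it cannot be read
        ($gguf:expr => $key:ident) => {
            $gguf.$key().unwrap()
        };
        // optional key: fall back to a default when the key does not exist
        ($gguf:expr => $key:ident; $default:expr) => {
            match $gguf.$key() {
                Ok(val) => val,
                Err(MetaError::NotExist) => $default,
                Err(e) => panic!("failed to read meta: {e:?}"),
            }
        };
    }

    fn main() {
        let gguf = MockModel;
        let nblk = meta!(gguf => llm_block_count);     // 32
        let nexp = meta!(gguf => llm_expert_count; 0); // key missing -> 0
        println!("nblk = {nblk}, nexp = {nexp}");
    }

Keeping the `Err(e) => panic!` arm means a present-but-malformed key still fails loudly; only a genuinely missing key falls back to the default.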

models/llama/common/src/compute.rs

Lines changed: 31 additions & 29 deletions

@@ -195,7 +195,6 @@ where
             nh,
             nkvh,
             nexp,
-            nexp_use,
             dh,
             di,
             ..
@@ -230,6 +229,7 @@ where
         });

         let req_split = requests.iter().map(|req| req.seq_len).collect::<Vec<_>>();
+        let tok_split = vec![1; nt];

         let queue = queue_alloc.queue();
         for iblk in 0..nblk {
@@ -324,38 +324,19 @@ where

             Ops::memcpy_d2h(&mut routes_host, routes_dev.get(), queue)
         }
-        let ([], routes, []) = (unsafe { routes_host.align_to_mut::<f16>() }) else {
+        let ([], mut routes, []) = (unsafe { routes_host.align_to::<f16>() }) else {
             unreachable!()
         };

-        for itok in (0..nt).rev() {
-            // fused topk
-            let mut routes = routes[itok * nexp..][..nexp]
-                .iter()
-                .copied()
-                .enumerate()
-                .collect::<Vec<_>>();
-
-            routes.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
-            let max = routes[0].1.to_f32();
-            let mut sum = 0.;
-            let mut moe_gate = vec![(0, 0.0f32); nexp_use];
-            for ((i, x), gate) in std::iter::zip(routes, &mut moe_gate) {
-                let softmax = (x.to_f32() - max).exp();
-                *gate = (i, softmax);
-                sum += softmax
-            }
-            for (_, x) in &mut moe_gate {
-                *x /= sum
-            }
-            // mlp
-            let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
-            let mut gate_up = gate_up.clone().map(|_| buf);
-
-            let mut x = x.map_slice_mut().slice(0, itok, 0, 1);
-            let x1 = x1.map_slice_mut().slice(0, itok, 0, 1);
+        let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
+        let mut gate_up = gate_up.clone().map(|_| buf);

-            for (iexp, kexp) in moe_gate {
+        let x = x.split(0, &tok_split);
+        let x1 = x1.split(0, &tok_split);
+        for (mut x, x1) in izip!(x, x1) {
+            let (line, tail) = routes.split_at(nexp);
+            routes = tail;
+            for (iexp, kexp) in self.topk_with_index(line) {
                 let w = self.weights.ffn_gate_up(iblk, iexp, queue);
                 self.mat_mul(&mut gate_up, 0., &x1, &w, 1., workspace, queue_alloc)?;
                 drop(w);
@@ -409,6 +390,27 @@ where
     Ops: Operators,
     W: WeightLoader<Hardware = Ops::Hardware>,
 {
+    fn topk_with_index(&self, line: &[f16]) -> Vec<(usize, f32)> {
+        let mut routes = line
+            .iter()
+            .map(|&x| x.to_f32())
+            .enumerate()
+            .collect::<Vec<_>>();
+        routes.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
+        routes.truncate(self.meta.nexp_use);
+        // standard softmax
+        let (_, max) = routes[0];
+        let mut sum = 0.;
+        for (_, x) in &mut routes {
+            *x = (*x - max).exp();
+            sum += *x
+        }
+        for (_, x) in &mut routes {
+            *x /= sum
+        }
+        routes
+    }
+
     fn rms_norm<Y, X, W_, QA>(
         &self,
         y: &mut Tensor<Y>,
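The refactor moves the inlined "fused topk" into the reusable `topk_with_index` helper: pick the `nexp_use` largest gate logits for a token, then softmax over just those. Here is a standalone sketch of that routing math, written over plain f32 logits so it runs without the `half` crate (the committed code converts from f16 first and reads `k` from `self.meta.nexp_use`):

    // Pick the k largest gate logits, then softmax over only those k.
    fn topk_with_index(line: &[f32], k: usize) -> Vec<(usize, f32)> {
        let mut routes = line.iter().copied().enumerate().collect::<Vec<_>>();
        // sort descending by logit and keep the k best experts
        routes.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
        routes.truncate(k);
        // numerically stable softmax: subtract the max (routes[0] after the sort)
        let (_, max) = routes[0];
        let mut sum = 0.;
        for (_, x) in &mut routes {
            *x = (*x - max).exp();
            sum += *x;
        }
        for (_, x) in &mut routes {
            *x /= sum;
        }
        routes
    }

    fn main() {
        // 4 experts, route the token to the top 2
        let gates = topk_with_index(&[0.1, 2.0, -1.0, 1.5], 2);
        println!("{gates:?}"); // ~[(1, 0.622), (3, 0.378)]
    }

Because the softmax normalizes over only the selected logits, each token's k expert weights sum to 1.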

models/llama/common/src/storage.rs

Lines changed: 26 additions & 43 deletions

@@ -26,61 +26,44 @@ pub struct BlkStorage<T> {

 impl<'a> Storage<&'a [u8]> {
     pub fn from_gguf(gguf: &GGufModel<'a>) -> Self {
-        macro_rules! meta {
-            ($key:ident) => {
-                gguf.$key().unwrap()
-            };
-            ($key:ident; $default:expr) => {
-                match gguf.$key() {
-                    Ok(val) => val,
-                    Err(gguf::GGufMetaError::NotExist) => $default,
-                    Err(e) => panic!("failed to read meta: {e:?}"),
-                }
-            };
-        }
-        macro_rules! tensor {
-            ($name:expr) => {
-                &gguf.tensors[&*$name]
-            };
-        }
-
-        let token_embd  = tensor!["token_embd.weight"];
-        let output_norm = tensor!["output_norm.weight"];
-        let qkv0        = tensor!["blk.0.attn_qkv.weight"];
+        use gguf::{meta, tensor};
+        let token_embd  = tensor![gguf => "token_embd.weight"];
+        let output_norm = tensor![gguf => "output_norm.weight"];
+        let qkv0        = tensor![gguf => "blk.0.attn_qkv.weight"];
         #[rustfmt::skip]
         let meta = LlamaMeta {
             dt_embd : token_embd.ty,
             dt_norm : output_norm.ty,
             dt_mat  : qkv0.ty,

-            nblk    : meta!(llm_block_count            ),
-            nctx    : meta!(llm_context_length         ),
-            nvoc    : meta!(tokenizer_ggml_tokens).len(),
-            nh      : meta!(llm_attention_head_count   ),
-            nkvh    : meta!(llm_attention_head_count_kv),
-            nexp    : meta!(llm_expert_count           ; 0),
-            nexp_use: meta!(llm_expert_used_count      ; 0),
-            d       : meta!(llm_embedding_length       ),
-            dh      : meta!(llm_rope_dimension_count   ),
-            di      : meta!(llm_feed_forward_length    ),
-
-            epsilon : meta!(llm_attention_layer_norm_rms_epsilon; 1e-5),
-            theta   : meta!(llm_rope_freq_base                  ; 1e4 ),
+            nblk    : meta!(gguf => llm_block_count            ),
+            nctx    : meta!(gguf => llm_context_length         ),
+            nvoc    : meta!(gguf => tokenizer_ggml_tokens).len(),
+            nh      : meta!(gguf => llm_attention_head_count   ),
+            nkvh    : meta!(gguf => llm_attention_head_count_kv),
+            nexp    : meta!(gguf => llm_expert_count           ; 0),
+            nexp_use: meta!(gguf => llm_expert_used_count      ; 0),
+            d       : meta!(gguf => llm_embedding_length       ),
+            dh      : meta!(gguf => llm_rope_dimension_count   ),
+            di      : meta!(gguf => llm_feed_forward_length    ),
+
+            epsilon : meta!(gguf => llm_attention_layer_norm_rms_epsilon; 1e-5),
+            theta   : meta!(gguf => llm_rope_freq_base                  ; 1e4 ),
         };

         #[rustfmt::skip]
         let blocks = (0..meta.nblk)
             .map(|i| BlkStorage {
-                attn_norm:    tensor![format!("blk.{i}.attn_norm.weight"  )].data,
-                attn_qkv:     tensor![format!("blk.{i}.attn_qkv.weight"   )].data,
-                attn_o:       tensor![format!("blk.{i}.attn_output.weight")].data,
-                ffn_norm:     tensor![format!("blk.{i}.ffn_norm.weight"   )].data,
+                attn_norm:    tensor![gguf => format!("blk.{i}.attn_norm.weight"  )].data,
+                attn_qkv:     tensor![gguf => format!("blk.{i}.attn_qkv.weight"   )].data,
+                attn_o:       tensor![gguf => format!("blk.{i}.attn_output.weight")].data,
+                ffn_norm:     tensor![gguf => format!("blk.{i}.ffn_norm.weight"   )].data,
                 ffn_gate_inp: if !meta.is_moe() { None }
-                              else { Some(tensor![format!("blk.{i}.ffn_gate_inp.weight"   )].data) },
-                ffn_gate_up : if !meta.is_moe() { tensor![format!("blk.{i}.ffn_gate_up.weight"   )].data }
-                              else {             tensor![format!("blk.{i}.ffn_gate_up_exps.weight")].data },
-                ffn_down    : if !meta.is_moe() { tensor![format!("blk.{i}.ffn_down.weight"      )].data }
-                              else {             tensor![format!("blk.{i}.ffn_down_exps.weight"  )].data },
+                              else { Some(tensor![gguf => format!("blk.{i}.ffn_gate_inp.weight"   )].data) },
+                ffn_gate_up : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_gate_up.weight"   )].data }
+                              else {             tensor![gguf => format!("blk.{i}.ffn_gate_up_exps.weight")].data },
+                ffn_down    : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_down.weight"      )].data }
+                              else {             tensor![gguf => format!("blk.{i}.ffn_down_exps.weight"  )].data },
             })
             .collect();
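The `is_moe()` branches above encode the tensor-naming convention this loader expects: dense checkpoints carry fused `ffn_gate_up`/`ffn_down` weights per block, while MoE checkpoints add a router matrix (`ffn_gate_inp`) and per-expert `*_exps` variants. A sketch with a hypothetical `ffn_tensor_names` helper, written only to make that dispatch explicit:

    // Hypothetical helper mirroring the name dispatch in `from_gguf`.
    // Returns (router, gate_up, down) names for block `iblk`;
    // the router is `None` for a dense (non-MoE) block.
    fn ffn_tensor_names(iblk: usize, is_moe: bool) -> (Option<String>, String, String) {
        if is_moe {
            (
                Some(format!("blk.{iblk}.ffn_gate_inp.weight")), // routing weights
                format!("blk.{iblk}.ffn_gate_up_exps.weight"),   // per-expert gate+up
                format!("blk.{iblk}.ffn_down_exps.weight"),      // per-expert down
            )
        } else {
            (
                None, // a dense FFN has no router
                format!("blk.{iblk}.ffn_gate_up.weight"),
                format!("blk.{iblk}.ffn_down.weight"),
            )
        }
    }

    fn main() {
        let (router, gate_up, down) = ffn_tensor_names(0, true);
        println!("{router:?}, {gate_up}, {down}");
    }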
