
Commit e14c588

feat(llama): initial implementation of llama-moe

Signed-off-by: YdrMaster <[email protected]>
1 parent 6d3cf32 commit e14c588

2 files changed: +149 −25 lines changed

models/llama/common-cpu/src/lib.rs

Lines changed: 36 additions & 0 deletions
@@ -31,6 +31,7 @@ pub struct Weights<'w> {
     weight_cache: RefCell<WeightCache>,
     dt_embd: DigitLayout,
     dt_mat: DigitLayout,
+    nexp: usize,
     size_qkv: usize,
     size_o: usize,
     size_gate_up: usize,
@@ -70,6 +71,16 @@ where
     {
         println!("{tensor}");
     }
+
+    fn memcpy_d2h<T: Copy>(
+        dst: &mut [T],
+        src: &[ByteOf<Self::Hardware>],
+        _queue: &QueueOf<Self::Hardware>,
+    ) {
+        let count = size_of_val(dst);
+        assert_eq!(size_of_val(src), count);
+        unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
+    }
 }

 impl<'w> Weights<'w> {
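Note: on the CPU backend the "device" is the host, so memcpy_d2h reduces to a length-checked raw byte copy into a typed slice. A minimal, self-contained sketch of the same pattern (the function name and the main below are illustrative, not part of the crate):

// Hypothetical standalone version of the copy above: destination is a typed
// slice, source is raw bytes, and both must describe exactly the same byte count.
fn copy_bytes_into<T: Copy>(dst: &mut [T], src: &[u8]) {
    let count = std::mem::size_of_val(dst);
    assert_eq!(src.len(), count);
    // Safety: lengths match, and the exclusive borrow of `dst` rules out aliasing.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
}

fn main() {
    let src = 1.5f32.to_ne_bytes();
    let mut dst = [0.0f32; 1];
    copy_bytes_into(&mut dst, &src);
    assert_eq!(dst[0], 1.5);
}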
@@ -130,6 +141,7 @@ impl<'w> Weights<'w> {
             weight_cache,
             dt_embd: meta.dt_embd,
             dt_mat: meta.dt_mat,
+            nexp: meta.nexp,
             size_qkv,
             size_o,
             size_gate_up,
@@ -280,6 +292,30 @@ impl WeightLoader for Weights<'_> {
         )
     }

+    fn load_moe<'a>(
+        &'a self,
+        which: BlkWeight,
+        iblk: usize,
+        iexp: usize,
+        _queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a> {
+        let &Self {
+            ref blks,
+            dt_embd,
+            dt_mat,
+            nexp,
+            ..
+        } = self;
+        assert_eq!(dt_embd, dt_mat);
+        let w = match which {
+            BlkWeight::FfnGateUp => &*blks[iblk].ffn_gate_up,
+            BlkWeight::FfnDown => &*blks[iblk].ffn_down,
+            _ => unreachable!(),
+        };
+        let one = w.len() / nexp;
+        Dequant::Borrowed(&w[iexp * one..][..one])
+    }
+
     #[inline]
     fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
         Dequant::Borrowed(self.output_norm)
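Note: load_moe relies on the per-block FFN buffers storing all nexp experts back to back, equally sized, so selecting expert iexp is pure index arithmetic. A small, self-contained sketch of that slicing (helper name and sizes are hypothetical):

// Hypothetical helper mirroring the slicing in `load_moe`: the concatenated
// buffer `w` holds `nexp` equally sized expert tensors; expert `iexp` is the
// `iexp`-th chunk.
fn expert_slice(w: &[u8], nexp: usize, iexp: usize) -> &[u8] {
    assert_eq!(w.len() % nexp, 0, "buffer must split evenly across experts");
    let one = w.len() / nexp;
    &w[iexp * one..][..one]
}

fn main() {
    let w = vec![0u8; 4096]; // e.g. 4 experts of 1024 bytes each
    assert_eq!(expert_slice(&w, 4, 3).len(), 1024);
}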

models/llama/common/src/compute.rs

Lines changed: 113 additions & 25 deletions
@@ -1,5 +1,8 @@
 use super::{args::Args, LlamaMeta};
-use gguf::ggml_quants::digit_layout::{types as ty, DigitLayout};
+use gguf::ggml_quants::{
+    digit_layout::{types as ty, DigitLayout},
+    f16,
+};
 use itertools::izip;
 use operators::{
     all_reduce::{self, AllReduce, ReduceOp},
@@ -12,7 +15,7 @@ use operators::{
     ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
-use tensor::{split, Tensor};
+use tensor::{split, Blob, Tensor};

 pub trait Operators {
     type Hardware: Hardware;
@@ -29,6 +32,14 @@ pub trait Operators {
     where
         T: Deref<Target = [ByteOf<Self::Hardware>]>;

+    fn memcpy_d2h<T: Copy>(
+        _dst: &mut [T],
+        _src: &[ByteOf<Self::Hardware>],
+        _queue: &QueueOf<Self::Hardware>,
+    ) {
+        todo!()
+    }
+
     fn build_sin_cos<QA>(
         dt: DigitLayout,
         nctx: usize,
@@ -68,6 +79,16 @@ pub trait WeightLoader {
         queue: &'a QueueOf<Self::Hardware>,
     ) -> Self::Weight<'a>;

+    fn load_moe<'a>(
+        &'a self,
+        _which: BlkWeight,
+        _iblk: usize,
+        _iexp: usize,
+        _queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a> {
+        todo!()
+    }
+
     fn output_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
     fn output<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
 }
@@ -118,23 +139,30 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {

     pub fn workspace_size(&self, nt: usize, max_seq_len: usize, max_att_len: usize) -> usize {
         let LlamaMeta {
-            dt_mat,
             nh,
             nkvh,
-            d,
+            nexp,
             dh,
             di,
             ..
         } = self.meta;

-        let ele = dt_mat.nbytes();
-        let embd = nt * d * ele;
-        let qkv = nt * (nh + nkvh + nkvh) * dh * ele;
-        let gate_up = nt * di * 2 * ele;
-        let q = max_seq_len * nh * dh * ele;
-        let att = nkvh * max_seq_len * max_att_len * ele;
-
-        embd + qkv.max(gate_up) + q + att
+        let embd = self.meta.embd(nt);
+        let dt = embd.dt();
+        let embd = embd.take();
+
+        let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take();
+        let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take();
+        let att = Tensor::new(dt, &[nkvh, max_seq_len, max_att_len]).take();
+
+        if self.meta.is_moe() {
+            let routes = Tensor::new(dt, &[nt, nexp]).take();
+            let gate_up = Tensor::new(dt, &[1, di * 2]).take();
+            embd + (qkv + q + att).max(routes).max(gate_up)
+        } else {
+            let gate_up = Tensor::new(dt, &[nt, di * 2]).take();
+            embd + (qkv + q + att).max(gate_up)
+        }
     }
 }
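Note: the MoE path changes the workspace shape rather than just its size: routing logits need an extra [nt, nexp] buffer, while the gate_up scratch shrinks to a single row because experts are applied one token at a time. A rough byte-level sketch of the same accounting, assuming an f16 activation type and ignoring any alignment Tensor::take() may add (names and the 2-byte element size are illustrative assumptions):

// Hypothetical byte-level restatement of the MoE branch above (f16 = 2 bytes).
fn moe_workspace_bytes(
    nt: usize, d: usize, nh: usize, nkvh: usize, dh: usize,
    di: usize, nexp: usize, max_seq_len: usize, max_att_len: usize,
) -> usize {
    let ele = 2; // assumed f16 element size
    let embd = nt * d * ele;
    let qkv = nt * (nh + nkvh + nkvh) * dh * ele;
    let q = max_seq_len * nh * dh * ele;
    let att = nkvh * max_seq_len * max_att_len * ele;
    let routes = nt * nexp * ele;
    let gate_up = di * 2 * ele; // one token at a time in the MoE path
    embd + (qkv + q + att).max(routes).max(gate_up)
}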

@@ -167,6 +195,7 @@ where
             nh,
             nkvh,
             nexp,
+            nexp_use,
             dh,
             di,
             ..
@@ -182,7 +211,7 @@ where
         let mut x1 = x1.map(|_| buf);

         let qkv = Tensor::new(x.dt(), &[nt, (nh + nkvh + nkvh) * dh]);
-        let gate_up = Tensor::new(x.dt(), &[nt, di * 2]);
+        let gate_up = Tensor::new(x.dt(), &[if self.meta.is_moe() { 1 } else { nt }, di * 2]);
         let routes = Tensor::new(x.dt(), &[nt, nexp]);

         let sin = sin_cos.clone().index(0, 0);
@@ -264,36 +293,81 @@ where
             }
             self.all_reduce(&mut x, workspace, queue_alloc)?;

-            if !self.meta.is_moe() {
-                let w = self.weights.ffn_norm(iblk, queue);
-                self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
-                drop(w);
+            let w = self.weights.ffn_norm(iblk, queue);
+            self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
+            drop(w);

+            if !self.meta.is_moe() {
                 let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
                 let mut gate_up = gate_up.clone().map(|_| buf);

-                let w = self.weights.ffn_gate_up(iblk, queue);
+                let w = self.weights.ffn_gate_up(iblk, 0, queue);
                 self.mat_mul(&mut gate_up, 0., &x1, &w, 1., workspace, queue_alloc)?;
                 drop(w);

                 split!(gate_up => gate, up; [di, di] @ 1);
                 let mut gate = gate;
                 self.swiglu(&mut gate, &up, workspace, queue_alloc)?;

-                let w = self.weights.ffn_down(iblk, queue);
+                let w = self.weights.ffn_down(iblk, 0, queue);
                 self.mat_mul(&mut x, residual, &gate, &w, 1., workspace, queue_alloc)?
             } else {
+                let mut routes_host = routes.clone().map(Blob::new).take();
+                // gate_inp
                 {
                     let (buf, workspace) = workspace.split_at_mut(*routes.get());
-                    let mut routes = routes.clone().map(|_| buf);
+                    let mut routes_dev = routes.clone().map(|_| buf);

                     let w = self.weights.ffn_gate_inp(iblk, queue);
-                    self.mat_mul(&mut routes, 0., &x, &w, 1., workspace, queue_alloc)?;
+                    self.mat_mul(&mut routes_dev, 0., &x1, &w, 1., workspace, queue_alloc)?;
                     drop(w);

-                    todo!()
+                    Ops::memcpy_d2h(&mut routes_host, routes_dev.get(), queue)
+                }
+                let ([], routes, []) = (unsafe { routes_host.align_to_mut::<f16>() }) else {
+                    unreachable!()
+                };
+
+                for itok in (0..nt).rev() {
+                    // fused topk
+                    let mut routes = routes[itok * nexp..][..nexp]
+                        .iter()
+                        .copied()
+                        .enumerate()
+                        .collect::<Vec<_>>();
+
+                    routes.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
+                    let max = routes[0].1.to_f32();
+                    let mut sum = 0.;
+                    let mut moe_gate = vec![(0, 0.0f32); nexp_use];
+                    for ((i, x), gate) in std::iter::zip(routes, &mut moe_gate) {
+                        let softmax = (x.to_f32() - max).exp();
+                        *gate = (i, softmax);
+                        sum += softmax
+                    }
+                    for (_, x) in &mut moe_gate {
+                        *x /= sum
+                    }
+                    // mlp
+                    let (buf, workspace) = workspace.split_at_mut(*gate_up.get());
+                    let mut gate_up = gate_up.clone().map(|_| buf);
+
+                    let mut x = x.map_slice_mut().slice(0, itok, 0, 1);
+                    let x1 = x1.map_slice_mut().slice(0, itok, 0, 1);
+
+                    for (iexp, kexp) in moe_gate {
+                        let w = self.weights.ffn_gate_up(iblk, iexp, queue);
+                        self.mat_mul(&mut gate_up, 0., &x1, &w, 1., workspace, queue_alloc)?;
+                        drop(w);
+
+                        split!(gate_up => gate, up; [di, di] @ 1);
+                        let mut gate = gate;
+                        self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
+
+                        let w = self.weights.ffn_down(iblk, iexp, queue);
+                        self.mat_mul(&mut x, residual, &gate, &w, kexp, workspace, queue_alloc)?
+                    }
                 }
             }
-            // TODO MLP
             }
             self.all_reduce(&mut x, workspace, queue_alloc)?
         }
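Note: the routing in this hunk is a plain top-k softmax done on the host: sort a token's expert logits, keep the top nexp_use, apply a max-shifted softmax over just those, and use the normalized weights to scale each selected expert's output. A standalone sketch of only that gating step (f32 instead of f16, names hypothetical):

// Hypothetical standalone top-k softmax gate mirroring the loop above.
// Returns (expert index, normalized weight) for the `k` highest-scoring experts.
fn topk_softmax(logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut routes: Vec<_> = logits.iter().copied().enumerate().collect();
    // Sort by logit, descending, so the first `k` entries are the selected experts.
    routes.sort_unstable_by(|&(_, a), &(_, b)| b.total_cmp(&a));
    routes.truncate(k);
    let max = routes[0].1;
    // Softmax over the selected logits only, shifted by the max for stability.
    let mut gate: Vec<_> = routes.into_iter().map(|(i, x)| (i, (x - max).exp())).collect();
    let sum: f32 = gate.iter().map(|&(_, w)| w).sum();
    for (_, w) in &mut gate {
        *w /= sum
    }
    gate
}

fn main() {
    let gate = topk_softmax(&[0.1, 2.0, -1.0, 1.0], 2);
    assert_eq!(gate[0].0, 1); // expert 1 has the largest logit
    assert!((gate.iter().map(|&(_, w)| w).sum::<f32>() - 1.0).abs() < 1e-6);
}

In the diff, each kexp is then passed where the dense path passes 1., so the selected expert's down-projection is scaled by its gate weight as it accumulates into that token's row of x.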
@@ -553,6 +627,7 @@ struct WeightDecorator<W> {
     ffn_gate_up: Tensor<usize>,
     ffn_down: Tensor<usize>,
     output: Tensor<usize>,
+    is_moe: bool,
     weights: W,
 }

@@ -567,6 +642,7 @@ impl LlamaMeta {
             ffn_gate_up: self.ffn_gate_up(Computation),
             ffn_down: self.ffn_down(Computation),
             output: self.output(),
+            is_moe: self.is_moe(),
             weights,
         }
     }
@@ -627,19 +703,31 @@ impl<W: WeightLoader> WeightDecorator<W> {
     pub fn ffn_gate_up<'a>(
         &'a self,
         iblk: usize,
+        iexp: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::FfnGateUp, iblk, queue);
+        const WHICH: BlkWeight = BlkWeight::FfnGateUp;
+        let w = if self.is_moe {
+            self.weights.load_moe(WHICH, iblk, iexp, queue)
+        } else {
+            self.weights.load_blk(WHICH, iblk, queue)
+        };
         self.ffn_gate_up.clone().map(|_| w)
     }

     #[inline]
     pub fn ffn_down<'a>(
         &'a self,
         iblk: usize,
+        iexp: usize,
         queue: &'a QueueOf<W::Hardware>,
     ) -> Tensor<W::Weight<'a>> {
-        let w = self.weights.load_blk(BlkWeight::FfnDown, iblk, queue);
+        const WHICH: BlkWeight = BlkWeight::FfnDown;
+        let w = if self.is_moe {
+            self.weights.load_moe(WHICH, iblk, iexp, queue)
+        } else {
+            self.weights.load_blk(WHICH, iblk, queue)
+        };
         self.ffn_down.clone().map(|_| w)
     }
