Commit e984da5

feat(llama): implement MoE-capable distributed partitioning and inter-thread tensor parallelism
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 2896798 commit e984da5

File tree

12 files changed: +156 -77 lines changed


models/llama/common-cpu/src/infer.rs

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ fn test_infer() {
             Some(s.spawn(move || {
                 let WorkerSeed { node, tasks } = seed;
                 let weights = Weights::new(model, range, count);
-                let mut worker = Worker::new(id, &node, meta.clone(), weights, id == 0);
+                let mut worker = Worker::new(id, &node, meta.clone(), weights);
                 let mut cache = meta.kv_cache(meta.nctx).map(Blob::new);
                 let sin_cos = <Operators as llama::Operators>::build_sin_cos(
                     meta.dt_embd,
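
Note on the call above: the `residual` flag (`id == 0`) is no longer passed to `Worker::new`; the worker now derives residual handling from its own `id` (see compute.rs below). As a rough illustration of the inter-thread partitioning this test exercises, here is a minimal sketch of how a per-thread weight shard could be derived; `shard`, `di`, and the sizes are hypothetical and not this crate's API:

use std::ops::Range;

/// Hypothetical helper: the contiguous slice of the FFN width `di`
/// owned by thread `id` out of `count` tensor-parallel threads.
fn shard(di: usize, id: usize, count: usize) -> Range<usize> {
    assert_eq!(di % count, 0); // assume the width splits evenly
    let step = di / count;
    step * id..step * (id + 1)
}

fn main() {
    let di = 11008; // illustrative FFN intermediate size
    for id in 0..4 {
        println!("thread {id} owns columns {:?}", shard(di, id, 4));
    }
}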

models/llama/common-cpu/src/lib.rs

Lines changed: 12 additions & 11 deletions
@@ -17,6 +17,7 @@ use std::{
     marker::PhantomData,
     mem::size_of,
     ops::{Deref, Range, RangeBounds},
+    ptr::copy_nonoverlapping,
     slice::{from_raw_parts, from_raw_parts_mut},
 };
 
@@ -69,7 +70,7 @@ where
     where
         T: Deref<Target = [ByteOf<Self::Hardware>]>,
     {
-        println!("{tensor}");
+        println!("{tensor}")
     }
 
     fn memcpy_d2h<T: Copy>(
@@ -79,7 +80,7 @@ where
     ) {
         let count = size_of_val(dst);
         assert_eq!(size_of_val(src), count);
-        unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
+        unsafe { copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
     }
 }
 
@@ -236,13 +237,13 @@ impl WeightLoader for Weights<'_> {
 
         #[rustfmt::skip]
         match which {
-            AttnNorm                        => return Borrowed(attn_norm   ),
-            AttnQKV    if dt_mat == dt_embd => return Borrowed(attn_qkv    ),
-            AttnO      if dt_mat == dt_embd => return Borrowed(attn_o      ),
-            FfnNorm                         => return Borrowed(ffn_norm    ),
-            FfnGateInp if dt_mat == dt_embd => return Borrowed(ffn_gate_inp.as_ref().unwrap()),
-            FfnGateUp  if dt_mat == dt_embd => return Borrowed(ffn_gate_up),
-            FfnDown    if dt_mat == dt_embd => return Borrowed(ffn_down    ),
+            AttnNorm                        => return Borrowed(attn_norm   ),
+            AttnQKV    if dt_mat == dt_embd => return Borrowed(attn_qkv    ),
+            AttnO      if dt_mat == dt_embd => return Borrowed(attn_o      ),
+            FfnNorm                         => return Borrowed(ffn_norm    ),
+            FfnGateInp if dt_mat == dt_embd => return Borrowed(ffn_gate_inp),
+            FfnGateUp  if dt_mat == dt_embd => return Borrowed(ffn_gate_up ),
+            FfnDown    if dt_mat == dt_embd => return Borrowed(ffn_down    ),
             _ => {}
         };
 
@@ -265,7 +266,7 @@ impl WeightLoader for Weights<'_> {
             match which {
                 AttnQKV => dequant(dt_mat, dt_embd, attn_qkv, &mut cache[..size_qkv]),
                 AttnO => dequant(dt_mat, dt_embd, attn_o, &mut cache[..size_o]),
-                FfnGateInp => todo!(),
+                FfnGateInp => todo!("dequant ffn gate inp"),
                 FfnGateUp | FfnDown => {
                     dequant(dt_mat, dt_embd, ffn_gate_up, &mut cache[..size_gate_up]);
                     dequant(
@@ -284,7 +285,7 @@ impl WeightLoader for Weights<'_> {
             match which {
                 AttnQKV => 0..size_qkv,
                 AttnO => 0..size_o,
-                FfnGateInp => todo!(),
+                FfnGateInp => todo!("dequant ffn gate inp"),
                 FfnGateUp => 0..size_gate_up,
                 FfnDown => size_gate_up..size_gate_up + size_down,
                 AttnNorm | FfnNorm => unreachable!(),
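
The hunk above only hoists the `copy_nonoverlapping` import and drops the `Option` wrapper around `ffn_gate_inp`. For reference, a self-contained sketch of the same host-side byte-copy pattern that the CPU `memcpy_d2h` relies on (standalone names, not the crate's trait):

use std::{mem::size_of_val, ptr::copy_nonoverlapping};

// Copy raw bytes into a typed slice; the byte counts must match and the
// regions must not overlap, mirroring the asserts in the diff above.
fn memcpy_bytes<T: Copy>(dst: &mut [T], src: &[u8]) {
    let count = size_of_val(dst);
    assert_eq!(src.len(), count);
    unsafe { copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
}

fn main() {
    let src = 0x1234_5678_u32.to_ne_bytes();
    let mut dst = [0u32; 1];
    memcpy_bytes(&mut dst, &src);
    assert_eq!(dst[0], 0x1234_5678);
}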

models/llama/common/src/compute.rs

Lines changed: 16 additions & 24 deletions
@@ -33,12 +33,10 @@ pub trait Operators {
         T: Deref<Target = [ByteOf<Self::Hardware>]>;
 
     fn memcpy_d2h<T: Copy>(
-        _dst: &mut [T],
-        _src: &[ByteOf<Self::Hardware>],
-        _queue: &QueueOf<Self::Hardware>,
-    ) {
-        todo!()
-    }
+        dst: &mut [T],
+        src: &[ByteOf<Self::Hardware>],
+        queue: &QueueOf<Self::Hardware>,
+    );
 
     fn build_sin_cos<QA>(
         dt: DigitLayout,
@@ -81,13 +79,11 @@ pub trait WeightLoader {
 
     fn load_moe<'a>(
         &'a self,
-        _which: BlkWeight,
-        _iblk: usize,
-        _iexp: usize,
-        _queue: &'a QueueOf<Self::Hardware>,
-    ) -> Self::Weight<'a> {
-        todo!()
-    }
+        which: BlkWeight,
+        iblk: usize,
+        iexp: usize,
+        queue: &'a QueueOf<Self::Hardware>,
+    ) -> Self::Weight<'a>;
 
     fn output_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
     fn output<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
@@ -105,17 +101,10 @@ pub struct LlamaWorker<Ops: Operators, W> {
     swiglu: Ops::Swiglu,
     rearrange: Ops::Rearrange,
     all_reduce: Ops::AllReduce,
-    residual: bool,
 }
 
 impl<Ops: Operators, W> LlamaWorker<Ops, W> {
-    pub fn new(
-        id: usize,
-        node: &Ops::TopoNode,
-        meta: LlamaMeta,
-        weights: W,
-        residual: bool,
-    ) -> Self {
+    pub fn new(id: usize, node: &Ops::TopoNode, meta: LlamaMeta, weights: W) -> Self {
         let processor = node.processor();
         Self {
             id,
@@ -128,7 +117,6 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {
             swiglu: Ops::Swiglu::new(processor),
             rearrange: Ops::Rearrange::new(processor),
             all_reduce: Ops::AllReduce::new(node),
-            residual,
         }
     }
 
@@ -199,7 +187,6 @@ where
             di,
             ..
         } = self.meta;
-        let residual = if self.residual { 1. } else { 0. };
 
         let workspace_size = self.workspace_size(nt, max_seq_len, max_att_len);
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
@@ -289,6 +276,7 @@ where
 
                 let o = q.merge(1..3).unwrap();
                 let w = self.weights.attn_o(iblk, queue);
+                let residual = if self.id == 0 { 1. } else { 0. };
                 self.mat_mul(&mut x, residual, &o, &w, 1., workspace, queue_alloc)?
             }
             self.all_reduce(&mut x, workspace, queue_alloc)?;
@@ -310,6 +298,7 @@ where
                 self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
 
                 let w = self.weights.ffn_down(iblk, 0, queue);
+                let residual = if self.id == 0 { 1. } else { 0. };
                 self.mat_mul(&mut x, residual, &gate, &w, 1., workspace, queue_alloc)?
             } else {
                 let mut routes_host = routes.clone().map(Blob::new).take();
@@ -336,6 +325,7 @@ where
                 for (mut x, x1) in izip!(x, x1) {
                     let (line, tail) = routes.split_at(nexp);
                     routes = tail;
+                    let mut first = true;
                     for (iexp, kexp) in self.topk_with_index(line) {
                         let w = self.weights.ffn_gate_up(iblk, iexp, queue);
                         self.mat_mul(&mut gate_up, 0., &x1, &w, 1., workspace, queue_alloc)?;
@@ -346,7 +336,9 @@ where
                         self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
 
                         let w = self.weights.ffn_down(iblk, iexp, queue);
-                        self.mat_mul(&mut x, residual, &gate, &w, kexp, workspace, queue_alloc)?
+                        let residual = if self.id == 0 || !first { 1. } else { 0. };
+                        self.mat_mul(&mut x, residual, &gate, &w, kexp, workspace, queue_alloc)?;
+                        first = false
                     }
                 }
             }
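
Why the residual scale moved next to each `mat_mul`: with inter-thread tensor parallelism every rank produces only a partial projection and `all_reduce` sums the partials, so the residual already stored in `x` must survive on exactly one rank (rank 0); in the MoE loop the other ranks zero `x` on their first expert and accumulate afterwards. A minimal numeric sketch of that convention, with plain floats standing in for tensors:

// Each rank contributes residual * x + partial; only rank 0 keeps the residual.
fn partial_output(id: usize, x_residual: f32, partial: f32) -> f32 {
    let residual = if id == 0 { 1. } else { 0. };
    residual * x_residual + partial
}

fn main() {
    let x = 1.5_f32; // residual activation held by every rank
    let partials = [0.2_f32, 0.3, 0.1, 0.4]; // per-rank partial matmul results
    // all_reduce sums the per-rank contributions
    let reduced: f32 = partials
        .iter()
        .enumerate()
        .map(|(id, &p)| partial_output(id, x, p))
        .sum();
    // the residual is counted exactly once, the partials are fully summed
    assert!((reduced - (x + 1.0)).abs() < 1e-6);
}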

models/llama/common/src/lib.rs

Lines changed: 26 additions & 5 deletions
@@ -109,27 +109,48 @@ impl LlamaMeta {
 
     pub fn ffn_gate_up(&self, usage: TensorUsage) -> Tensor<usize> {
         let &Self { d, di, .. } = self;
-        self.mat(di + di, d, usage)
+        self.mat_ffn(di + di, d, usage)
     }
 
     pub fn ffn_down(&self, usage: TensorUsage) -> Tensor<usize> {
         let &Self { d, di, .. } = self;
-        self.mat(d, di, usage)
+        self.mat_ffn(d, di, usage)
     }
 
     pub fn output(&self) -> Tensor<usize> {
         self.token_embd().transpose(&[1, 0])
     }
 
     fn mat(&self, row: usize, col: usize, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            dt_embd, dt_mat, ..
+        } = self;
+        // NOTICE: weight matrices are stored as the mat dtype but take part in computation as the embd dtype
+        match usage {
+            TensorUsage::Storage => Tensor::new(dt_mat, &[row, col / dt_mat.group_size()]),
+            TensorUsage::Computation => {
+                assert_eq!(dt_embd.group_size(), 1);
+                Tensor::new(dt_embd, &[row, col]).transpose(&[1, 0])
+            }
+        }
+    }
+
+    fn mat_ffn(&self, row: usize, col: usize, usage: TensorUsage) -> Tensor<usize> {
+        let &Self {
+            nexp,
+            dt_embd,
+            dt_mat,
+            ..
+        } = self;
         // NOTICE: weight matrices are stored as the mat dtype but take part in computation as the embd dtype
         match usage {
             TensorUsage::Storage => {
-                Tensor::new(self.dt_mat, &[row, col / self.dt_mat.group_size()])
+                let nexp = if nexp == 0 { 1 } else { nexp };
+                Tensor::new(dt_mat, &[nexp, row, col / dt_mat.group_size()])
             }
             TensorUsage::Computation => {
-                assert_eq!(self.dt_embd.group_size(), 1);
-                Tensor::new(self.dt_embd, &[row, col]).transpose(&[1, 0])
+                assert_eq!(dt_embd.group_size(), 1);
+                Tensor::new(dt_embd, &[row, col]).transpose(&[1, 0])
             }
         }
     }
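
The new `mat_ffn` gives FFN weights a leading expert dimension so dense and MoE models share one storage layout (a dense model keeps a degenerate expert axis of 1). A minimal shape sketch, using plain arrays instead of the crate's `Tensor` and illustrative sizes:

// Storage shape implied by mat_ffn: [experts, rows, columns / group_size].
fn ffn_storage_shape(nexp: usize, row: usize, col: usize, group_size: usize) -> [usize; 3] {
    let nexp = if nexp == 0 { 1 } else { nexp };
    [nexp, row, col / group_size]
}

fn main() {
    let (d, di) = (4096, 11008); // hypothetical hidden / FFN intermediate sizes
    // dense (nexp == 0): ffn_gate_up is stored as [1, 2*di, d]
    assert_eq!(ffn_storage_shape(0, di + di, d, 1), [1, 22016, 4096]);
    // MoE with 8 experts: ffn_gate_up is stored as [8, 2*di, d]
    assert_eq!(ffn_storage_shape(8, di + di, d, 1), [8, 22016, 4096]);
}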

models/llama/common/src/storage.rs

Lines changed: 16 additions & 21 deletions
@@ -19,7 +19,7 @@ pub struct BlkStorage<T> {
     pub attn_qkv: T,
     pub attn_o: T,
     pub ffn_norm: T,
-    pub ffn_gate_inp: Option<T>,
+    pub ffn_gate_inp: T,
     pub ffn_gate_up: T,
     pub ffn_down: T,
 }
@@ -58,12 +58,12 @@ impl<'a> Storage<&'a [u8]> {
             attn_qkv:     tensor![gguf => format!("blk.{i}.attn_qkv.weight"    )].data,
             attn_o:       tensor![gguf => format!("blk.{i}.attn_output.weight" )].data,
             ffn_norm:     tensor![gguf => format!("blk.{i}.ffn_norm.weight"    )].data,
-            ffn_gate_inp: if !meta.is_moe() { None }
-                          else { Some(tensor![gguf => format!("blk.{i}.ffn_gate_inp.weight"   )].data) },
-            ffn_gate_up : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_gate_up.weight"    )].data }
-                          else { tensor![gguf => format!("blk.{i}.ffn_gate_up_exps.weight")].data },
-            ffn_down    : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_down.weight"       )].data }
-                          else { tensor![gguf => format!("blk.{i}.ffn_down_exps.weight"   )].data },
+            ffn_gate_inp: if !meta.is_moe() { &[] }
+                          else { tensor![gguf => format!("blk.{i}.ffn_gate_inp.weight"   )].data },
+            ffn_gate_up : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_gate_up.weight"    )].data }
+                          else { tensor![gguf => format!("blk.{i}.ffn_gate_up_exps.weight")].data },
+            ffn_down    : if !meta.is_moe() { tensor![gguf => format!("blk.{i}.ffn_down.weight"       )].data }
+                          else { tensor![gguf => format!("blk.{i}.ffn_down_exps.weight"   )].data },
         })
         .collect();
 
@@ -84,7 +84,7 @@ impl<T> BlkStorage<T> {
             attn_qkv: f(self.attn_qkv),
             attn_o: f(self.attn_o),
             ffn_norm: f(self.ffn_norm),
-            ffn_gate_inp: self.ffn_gate_inp.map(&mut f),
+            ffn_gate_inp: f(self.ffn_gate_inp),
             ffn_gate_up: f(self.ffn_gate_up),
             ffn_down: f(self.ffn_down),
         }
@@ -96,7 +96,7 @@ impl<T> BlkStorage<T> {
             attn_qkv: &self.attn_qkv,
             attn_o: &self.attn_o,
             ffn_norm: &self.ffn_norm,
-            ffn_gate_inp: self.ffn_gate_inp.as_ref(),
+            ffn_gate_inp: &self.ffn_gate_inp,
             ffn_gate_up: &self.ffn_gate_up,
             ffn_down: &self.ffn_down,
         }
@@ -174,27 +174,22 @@ impl<'w> BlkStorage<&'w [u8]> {
                 own(o_.take())
             },
             ffn_norm: borrow(self.ffn_norm),
-            ffn_gate_inp: if len == count {
-                self.ffn_gate_inp.map(borrow)
-            } else {
-                todo!()
-            },
+            ffn_gate_inp: borrow(self.ffn_gate_inp),
             ffn_gate_up: if len == count {
                 borrow(self.ffn_gate_up)
             } else {
                 let gu = meta.ffn_gate_up(TensorMem).map(|_| self.ffn_gate_up);
-                split!(gu => g, u; [di, di] @ 0);
+                split!(gu => g, u; [di, di] @ 1);
 
                 let di = di / count;
 
-                let g = g.slice(0, di * start, 1, di * len);
-                let u = u.slice(0, di * start, 1, di * len);
-                debug_assert!(g.is_contiguous() && u.is_contiguous());
+                let g = g.slice(1, di * start, 1, di * len);
+                let u = u.slice(1, di * start, 1, di * len);
 
                 let mut ans = dis.ffn_gate_up(TensorMem).map(&mut f);
                 {
                     let ans = ans.map_slice_mut();
-                    split!(ans => g_, u_; [di * len , di * len] @ 0);
+                    split!(ans => g_, u_; [di * len , di * len] @ 1);
                     let mut g_ = g_;
                     let mut u_ = u_;
                     rearrange(&mut g_, &g);
@@ -207,8 +202,8 @@
             } else {
                 let down = meta.ffn_down(TensorMem).map(|_| self.ffn_down);
 
-                let d = down.shape()[1] / count;
-                let down = down.slice(1, d * start, 1, d * len);
+                let d = down.shape()[2] / count;
+                let down = down.slice(2, d * start, 1, d * len);
 
                 let mut down_ = Tensor::new(down.dt(), down.shape()).map(&mut f);
                 rearrange(&mut down_, &down);
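
With the leading expert axis in place, the tensor-parallel split of `ffn_gate_up` moves to axis 1 and that of `ffn_down` to axis 2 (hence the `@ 1` splits and `slice(1, ...)` / `slice(2, ...)` calls above). A minimal index sketch of the per-rank gate/up slicing; the helper is hypothetical, not the crate's `split!`/`slice` API:

use std::ops::Range;

// Along the gate_up axis of length 2*di, gate occupies the first di rows and
// up the second di; each rank takes its di/count slab of both halves.
fn gate_up_slice(di: usize, start: usize, len: usize, count: usize) -> (Range<usize>, Range<usize>) {
    let step = di / count; // rows owned per rank in each half
    let gate = step * start..step * (start + len);
    let up = di + step * start..di + step * (start + len);
    (gate, up)
}

fn main() {
    let (di, count) = (11008, 4);
    let (g, u) = gate_up_slice(di, 1, 1, count); // rank 1, a single shard
    assert_eq!(g, 2752..5504);
    assert_eq!(u, 13760..16512);
}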

models/llama/infini/src/infer.rs

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ fn test_infer() {
             let device = node.processor();
             let stream = device.stream();
             let weights = Weights::new(model, range, count, &stream);
-            let mut worker = Worker::new(id, &node, meta.clone(), weights, id == 0);
+            let mut worker = Worker::new(id, &node, meta.clone(), weights);
             let mut cache = meta
                 .kv_cache(meta.nctx)
                 .map(|size| stream.malloc::<u8>(size));
