Skip to content

Commit a0116a3

Browse files
committed
style(llama): MOE 依赖的 todo! 项分散到各后端实现
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 2896798 commit a0116a3

File tree

6 files changed

+87
-25
lines changed

6 files changed

+87
-25
lines changed

models/llama/common-cpu/src/lib.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::{
1717
marker::PhantomData,
1818
mem::size_of,
1919
ops::{Deref, Range, RangeBounds},
20+
ptr::copy_nonoverlapping,
2021
slice::{from_raw_parts, from_raw_parts_mut},
2122
};
2223

@@ -69,7 +70,7 @@ where
6970
where
7071
T: Deref<Target = [ByteOf<Self::Hardware>]>,
7172
{
72-
println!("{tensor}");
73+
println!("{tensor}")
7374
}
7475

7576
fn memcpy_d2h<T: Copy>(
@@ -79,7 +80,7 @@ where
7980
) {
8081
let count = size_of_val(dst);
8182
assert_eq!(size_of_val(src), count);
82-
unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
83+
unsafe { copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr().cast::<u8>(), count) }
8384
}
8485
}
8586

@@ -265,7 +266,7 @@ impl WeightLoader for Weights<'_> {
265266
match which {
266267
AttnQKV => dequant(dt_mat, dt_embd, attn_qkv, &mut cache[..size_qkv]),
267268
AttnO => dequant(dt_mat, dt_embd, attn_o, &mut cache[..size_o]),
268-
FfnGateInp => todo!(),
269+
FfnGateInp => todo!("dequant ffn gate inp"),
269270
FfnGateUp | FfnDown => {
270271
dequant(dt_mat, dt_embd, ffn_gate_up, &mut cache[..size_gate_up]);
271272
dequant(
@@ -284,7 +285,7 @@ impl WeightLoader for Weights<'_> {
284285
match which {
285286
AttnQKV => 0..size_qkv,
286287
AttnO => 0..size_o,
287-
FfnGateInp => todo!(),
288+
FfnGateInp => todo!("dequant ffn gate inp"),
288289
FfnGateUp => 0..size_gate_up,
289290
FfnDown => size_gate_up..size_gate_up + size_down,
290291
AttnNorm | FfnNorm => unreachable!(),

models/llama/common/src/compute.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,10 @@ pub trait Operators {
3333
T: Deref<Target = [ByteOf<Self::Hardware>]>;
3434

3535
fn memcpy_d2h<T: Copy>(
36-
_dst: &mut [T],
37-
_src: &[ByteOf<Self::Hardware>],
38-
_queue: &QueueOf<Self::Hardware>,
39-
) {
40-
todo!()
41-
}
36+
dst: &mut [T],
37+
src: &[ByteOf<Self::Hardware>],
38+
queue: &QueueOf<Self::Hardware>,
39+
);
4240

4341
fn build_sin_cos<QA>(
4442
dt: DigitLayout,
@@ -81,13 +79,11 @@ pub trait WeightLoader {
8179

8280
fn load_moe<'a>(
8381
&'a self,
84-
_which: BlkWeight,
85-
_iblk: usize,
86-
_iexp: usize,
87-
_queue: &'a QueueOf<Self::Hardware>,
88-
) -> Self::Weight<'a> {
89-
todo!()
90-
}
82+
which: BlkWeight,
83+
iblk: usize,
84+
iexp: usize,
85+
queue: &'a QueueOf<Self::Hardware>,
86+
) -> Self::Weight<'a>;
9187

9288
fn output_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
9389
fn output<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;

models/llama/common/src/storage.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,7 @@ impl<'w> BlkStorage<&'w [u8]> {
174174
own(o_.take())
175175
},
176176
ffn_norm: borrow(self.ffn_norm),
177-
ffn_gate_inp: if len == count {
178-
self.ffn_gate_inp.map(borrow)
179-
} else {
180-
todo!()
181-
},
177+
ffn_gate_inp: self.ffn_gate_inp.map(borrow),
182178
ffn_gate_up: if len == count {
183179
borrow(self.ffn_gate_up)
184180
} else {

models/llama/infini/src/lib.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,15 @@ where
5656
queue.synchronize();
5757
host
5858
});
59-
println!("{tensor}");
59+
println!("{tensor}")
60+
}
61+
62+
fn memcpy_d2h<T: Copy>(
63+
dst: &mut [T],
64+
src: &[ByteOf<Self::Hardware>],
65+
queue: &QueueOf<Self::Hardware>,
66+
) {
67+
queue.get_device().memcpy_d2h(dst, src)
6068
}
6169
}
6270

@@ -160,6 +168,20 @@ impl WeightLoader for Weights {
160168
}
161169
}
162170

171+
fn load_moe<'a>(
172+
&'a self,
173+
which: BlkWeight,
174+
iblk: usize,
175+
_iexp: usize,
176+
_queue: &'a QueueOf<Self::Hardware>,
177+
) -> Self::Weight<'a> {
178+
let _blk = &self.0.blocks[iblk];
179+
match which {
180+
BlkWeight::FfnGateUp | BlkWeight::FfnDown => todo!(),
181+
_ => unreachable!(),
182+
}
183+
}
184+
163185
#[inline]
164186
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
165187
&self.output_norm

models/llama/nvidia-gpu/src/lib.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,15 @@ where
146146
memcpy_d2h(&mut host, s);
147147
host
148148
});
149-
println!("{tensor}");
149+
println!("{tensor}")
150+
}
151+
152+
fn memcpy_d2h<T: Copy>(
153+
dst: &mut [T],
154+
src: &[ByteOf<Self::Hardware>],
155+
_queue: &QueueOf<Self::Hardware>,
156+
) {
157+
memcpy_d2h(dst, src)
150158
}
151159
}
152160

@@ -331,6 +339,19 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
331339
}
332340
}
333341

342+
fn load_moe<'a>(
343+
&'a self,
344+
which: BlkWeight,
345+
_iblk: usize,
346+
_iexp: usize,
347+
_queue: &'a QueueOf<Self::Hardware>,
348+
) -> Self::Weight<'a> {
349+
match which {
350+
BlkWeight::FfnGateUp | BlkWeight::FfnDown => todo!(),
351+
_ => unreachable!(),
352+
}
353+
}
354+
334355
#[inline]
335356
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
336357
WeightResult::Borrowed(&self.output_norm)

models/llama/opencl/src/lib.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use operators::{
1313
use std::{
1414
marker::PhantomData,
1515
ops::{Deref, RangeBounds},
16+
ptr::copy_nonoverlapping,
1617
};
1718

1819
pub struct Operators<N = ClDevice, R = NonAllReduce<ClDevice, Rearrange>>(PhantomData<(N, R)>);
@@ -49,7 +50,18 @@ where
4950
{
5051
let tensor = tensor.as_ref().map(|s| queue.map(s));
5152
println!("{tensor}");
52-
queue.unmap(tensor.take());
53+
queue.unmap(tensor.take())
54+
}
55+
56+
fn memcpy_d2h<T: Copy>(
57+
dst: &mut [T],
58+
src: &[ByteOf<Self::Hardware>],
59+
queue: &QueueOf<Self::Hardware>,
60+
) {
61+
assert_eq!(size_of_val(dst), size_of_val(src));
62+
let svm = queue.map(src);
63+
unsafe { copy_nonoverlapping(svm.as_ptr(), dst.as_mut_ptr().cast::<u8>(), dst.len()) }
64+
queue.unmap(svm)
5365
}
5466
}
5567

@@ -122,6 +134,20 @@ impl WeightLoader for Weights {
122134
}
123135
}
124136

137+
fn load_moe<'a>(
138+
&'a self,
139+
which: BlkWeight,
140+
iblk: usize,
141+
_iexp: usize,
142+
_queue: &'a QueueOf<Self::Hardware>,
143+
) -> Self::Weight<'a> {
144+
let _blk = &self.0.blocks[iblk];
145+
match which {
146+
BlkWeight::FfnGateUp | BlkWeight::FfnDown => todo!(),
147+
_ => unreachable!(),
148+
}
149+
}
150+
125151
#[inline]
126152
fn output_norm(&self, _queue: &QueueOf<Self::Hardware>) -> Self::Weight<'_> {
127153
&self.0.output_norm

0 commit comments

Comments
 (0)