Skip to content

Commit 4ac05d7

Browse files
committed
跑通mlp
1 parent dd4bc26 commit 4ac05d7

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

models/minicpm3/common/src/compute.rs

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ where
152152
// llama.cpp 定义死
153153
let scale_emb = 12f32;
154154
let scale_depth = 1.4f32;
155-
// 提前进行缩放
155+
// 残差连接时权重缩放
156156
let s = scale_depth / (nblk as f32).sqrt();
157157
fn ggml_scale(embd: *mut f16, s: f16, l: usize) {
158158
if l == 0 {
@@ -181,6 +181,11 @@ where
181181
let mut q = q.map(|_| buf);
182182
let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
183183
let mut kv_pe = kv_pe.map(|_| buf);
184+
// 进行 attention
185+
let attn = tensor(&[nt, nh, dv]);
186+
let (buf, workspace) = workspace.split_at_mut(*attn.get());
187+
let mut attn = attn.map(|_| buf);
188+
184189
let queue = queue_alloc.queue();
185190
for iblk in 0..nblk {
186191
// norm
@@ -323,15 +328,11 @@ where
323328
let k_rope_2 = k_rope_0.tile(1, &[1, dh]).broadcast(1, nh);
324329
self.rearrange(&mut k_rope_r, &k_rope_2, workspace, queue_alloc)?;
325330
self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
326-
// 进行 attention
327-
let attn = tensor(&[nt, nh, dv]);
328-
let (buf, workspace) = workspace.split_at_mut(*attn.get());
329-
let mut attn = attn.map(|_| buf);
330331

331332
let mut q = q3.transpose(&[1, 0]);
332333
let k = k.map_slice().transpose(&[1, 0]);
333334
let mut v = v.map_slice_mut().transpose(&[1, 0]);
334-
let mut attn = attn.transpose(&[1, 0]);
335+
let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
335336
self.attnention(
336337
&mut q,
337338
&k,
@@ -346,12 +347,9 @@ where
346347
let w = self.weights.attn_o(iblk, queue);
347348

348349
self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
350+
let inplace = unsafe { x.map_slice_static() };
351+
self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
349352
}
350-
let inplace = unsafe { x.map_slice_static() };
351-
//是否给 add 加上缩放系数
352-
353-
self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
354-
355353
let w = self.weights.ffn_norm(iblk, queue);
356354
self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
357355
drop(w);
@@ -361,29 +359,32 @@ where
361359
split!(gate_up => gate, up; [di, di] @ 1);
362360
let mut gate = gate;
363361
let mut up = up;
364-
let w = self.weights.ffn_gate(iblk, queue).transpose(&[0,1]);
362+
let w = self.weights.ffn_gate(iblk, queue);
365363
self.mat_mul(&mut gate, 0., &x1, &w, 1., workspace, queue_alloc)?;
366-
// Ops::debug(&w, queue);
364+
365+
let w = self.weights.ffn_up(iblk, queue);
366+
self.mat_mul(&mut up, 0., &x1, &w, 1., workspace, queue_alloc)?;
367+
368+
self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
367369

368370
fn print_first_10_elements(ptr: *const f16) {
369371
assert!(!ptr.is_null(), "Pointer must not be null");
370372

371373
unsafe {
372374
for i in 0..10 {
373-
// 逐个访问并打印前10个元素
375+
// 逐个访问并打印前 10 个元素
374376
let element = ptr.offset(i as isize).read();
375377
println!("Element {}: {:?}", i, element);
376378
}
377379
}
378380
}
379-
print_first_10_elements(w.base().cast::<f16>());
380-
todo!();
381-
382-
self.swiglu(&mut gate, &up, workspace, queue_alloc)?;
383381

384382
let w = self.weights.ffn_down(iblk, queue);
385-
let residual = if self.id == 0 { 1. } else { 0. };
386-
self.mat_mul(&mut x, residual, &gate, &w, 1., workspace, queue_alloc)?;
383+
self.mat_mul(&mut x1, 0., &gate, &w, s, workspace, queue_alloc)?;
384+
385+
let inplace = unsafe { x.map_slice_static() };
386+
self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
387+
387388
self.all_reduce(&mut x, workspace, queue_alloc)?
388389
}
389390
if logits.shape()[0] == 0 {
@@ -808,7 +809,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
808809
iblk: usize,
809810
queue: &'a QueueOf<W::Hardware>,
810811
) -> Tensor<W::Weight<'a>> {
811-
const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnGateUp;
812+
const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnGate;
812813
let w = self.weights.load_blk(WHICH, iblk, queue);
813814
self.ffn_gate.clone().map(|_| w)
814815
}
@@ -818,7 +819,7 @@ impl<W: WeightLoader> WeightDecorator<W> {
818819
iblk: usize,
819820
queue: &'a QueueOf<W::Hardware>,
820821
) -> Tensor<W::Weight<'a>> {
821-
const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnGateUp;
822+
const WHICH: MiniCPM3BlkWeight = MiniCPM3BlkWeight::FfnUp;
822823
let w = self.weights.load_blk(WHICH, iblk, queue);
823824
self.ffn_up.clone().map(|_| w)
824825
}

0 commit comments

Comments
 (0)