Commit 255cce4

Commit message: "t"
1 parent c469a04 commit 255cce4

3 files changed: +134 -51 lines changed
models/minicpm3/common-cpu/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ where
     type Rearrange = op!(rearrange);
     type Scale = op!(scale);
     type AttnKVCached = op!(attention_kv_cached);
+    type FuesdSoftmax = op!(fuesd_softmax);
     type AllReduce = R;

     fn debug<T>(tensor: &Tensor<T>, _queue: &QueueOf<Self::Hardware>)

models/minicpm3/common/src/compute.rs

Lines changed: 133 additions & 50 deletions
@@ -3,6 +3,8 @@ use gguf::ggml_quants::digit_layout::types as ty;
 use gguf::ggml_quants::digit_layout::DigitLayout;
 use half::f16;
 use itertools::Itertools;
+use operators::fuesd_softmax;
+use operators::fuesd_softmax::FusedSoftmax;
 use operators::scale;
 use operators::scale::Scale;
 use operators::{
@@ -19,6 +21,7 @@ use operators::{
     ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::ops::{Deref, DerefMut};
+use std::process::Output;
 use tensor::split_mut;
 use tensor::{split, Tensor};

@@ -33,6 +36,7 @@ pub trait Operators {
     type MatMul: MatMul<Self::Hardware>;
     type Swiglu: Swiglu<Self::Hardware>;
     type Scale: Scale<Self::Hardware>;
+    type FuesdSoftmax: FusedSoftmax<Self::Hardware>;
     type Rearrange: Rearrange<Self::Hardware>;
     type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;

@@ -87,6 +91,7 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
     mat_mul: Ops::MatMul,
     scale: Ops::Scale,
     swiglu: Ops::Swiglu,
+    fuesd_softmax: Ops::FuesdSoftmax,
     rearrange: Ops::Rearrange,
     all_reduce: Ops::AllReduce,
 }
@@ -109,6 +114,7 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
             all_reduce: Ops::AllReduce::new(node),
             dt_pos: ty::U64,
             attention: Ops::Attention::new(processor),
+            fuesd_softmax: Ops::FuesdSoftmax::new(processor),
         }
     }

@@ -165,12 +171,11 @@ where

         let gate_up = tensor(&[nt, di * 2]);
         // space: x + x1 + q (could probably be removed) + q3 + kv_pe + attn
-        let workspace_size = *x1.get() * 3 + *gate_up.get();
+        let workspace_size = *x1.get() * 20 + *gate_up.get();
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);

-
         let queue = queue_alloc.queue();

         let sin = sin_cos.clone().index(0, 0);
@@ -205,17 +210,15 @@ where
             let w = self.weights.attn_qa_norm(iblk, queue);
             self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
             {
-                // q [1, 768] q1 [1, 3840] kv_pe [1,288] kv [1, 5120] k [1, 3840] attn [1, 2560]
                 let q1 = tensor(&[nt, nh * dk]);
                 let (buf, workspace) = workspace.split_at_mut(*q1.get());
                 let mut q1 = q1.map(|_| buf);
                 let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
                 self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
-                drop(q);
-                // q3 is the data used in the attn computation, but we still need to apply the position embedding to part of q3
+
                 let mut q3 = q1.tile(1, &[nh, dk]);
                 let q2 = unsafe { q3.map_slice_static_mut() };
-                split_mut!(q2=>_q, q_rope;[dnope, dh]@ 2);
+                split_mut!(q2=>q_nope, q_rope;[dnope, dh]@ 2);

                 // kv_pe [1,288]
                 let kv_pe = tensor(&[nt, dkv_lora + dh]);
@@ -224,65 +227,125 @@ where

                 let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
                 self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
-
+                drop(q);
                 split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);

+                self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+                let mut k_rope = k_rope.tile(1, &[1, dh]);
+                self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+                let k_rope = k_rope.broadcast(1, nh);
+
                 let inplace = unsafe { kv_lora.map_slice_static() };
                 let w = self.weights.attn_kva_norm(iblk, queue);
                 self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
                 // kv X[1, 5120]
                 let kv = tensor(&[nt, nh * (dnope + dv)]);
                 let (buf, workspace) = workspace.split_at_mut(*kv.get());
                 let mut kv = kv.map(|_| buf);
-                let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
-
-                self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;
-
-                let kv = kv.tile(1, &[nh, dnope + dv]);
-
-                split_mut!(kv => k_nope ,v ; [dnope , dv ] @ 2);

-                // k [1, 3840]
-                let k = tensor(&[nt, nh, dk]);
-                let (buf, workspace) = workspace.split_at_mut(*k.get());
-                let k = k.map(|_| buf);
-
-                split_mut!(k => k_nope_r ,k_rope_r ; [dnope, dh] @ 2);
-
-                self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-                let mut k_rope = k_rope.tile(1, &[1, dh]);
-                self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
-                let k_rope = k_rope.broadcast(1, nh);
-                self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
-                self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
-
-                let pos = requests.last().unwrap().pos as f32;
-                let mut q = q3.transpose(&[1, 0]);
-                let k = k.map_slice().transpose(&[1, 0]);
-                let v = v.map_slice_mut().transpose(&[1, 0]);
-                // run attention
-                let attn = tensor(&[nt, nh, dv]);
-                let (buf, workspace) = workspace.split_at_mut(*attn.get());
-                let mut attn = attn.map(|_| buf);
-
-                let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
-                let pos = requests.last().unwrap().pos as f32;
-                self.attnention(
-                    &mut q,
-                    &k,
-                    &v,
-                    &mut attn,
-                    pos as usize,
+                let kv_b_proj = unsafe {
+                    self.weights
+                        .attn_kvb(iblk, queue)
+                        .tile(0, &[nh, dnope + dv])
+                        .map_slice_static()
+                };
+                split!(kv_b_proj=> q_absorb , out_absorb ; [dnope, dv] @ 1);
+                let inplace = unsafe { q_nope.map_slice_static() };
+
+                let q_nope_0 = q_nope.map_slice().transpose(&[1, 0]);
+                let q_nope_1 = tensor(&[nh, nt, dkv_lora]);
+                let (buf, workspace) = workspace.split_at_mut(*q_nope_1.get());
+                let mut q_nope = q_nope_1.map(|_| buf);
+                self.mat_mul(
+                    &mut q_nope,
+                    0.,
+                    &q_nope_0,
+                    &q_absorb,
+                    1.,
                     workspace,
                     queue_alloc,
                 )?;

-                let o = attn.transpose(&[1, 0]).merge(1..3).unwrap();
+                drop(q3);
+                // compute attn_weights
+                // TODO: DeepSeek applies a softmax_scale here
+                // Python reference:
+                // attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
+                let attn_weights = tensor(&[nh, nt, nt]);
+                let (buf, workspace) = workspace.split_at_mut(*attn_weights.get());
+                let mut attn_weights = attn_weights.map(|_| buf);
+                {
+                    let q_rope = q_rope.transpose(&[1, 0]);
+                    let k_rope = k_rope.transpose(&[1, 2, 0]);
+
+                    self.mat_mul(
+                        &mut attn_weights,
+                        0.,
+                        &q_rope,
+                        &k_rope,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                    let kv_lora = kv_lora
+                        .map_slice()
+                        .tile(0, &[1, 1])
+                        .broadcast(0, nh)
+                        .transpose(&[2, 1]);
+                    self.mat_mul(
+                        &mut attn_weights,
+                        1.,
+                        &q_nope,
+                        &kv_lora,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                }
+                // softmax
+                self.softmax(&mut attn_weights, workspace, queue_alloc)?;
+                // attn_output
+                let attn_output_r = tensor(&[nt, nh, dv]);
+                let (buf, workspace) = workspace.split_at_mut(*attn_output_r.get());
+                let mut attn_output_r = attn_output_r.map(|_| buf);
+                {
+                    let attn_output = tensor(&[nh, nt, dkv_lora]);
+                    let (buf, workspace) = workspace.split_at_mut(*attn_output.get());
+                    let mut attn_output = attn_output.map(|_| buf);
+                    let kv_lora = kv_lora.tile(0, &[1, 1]).broadcast(0, nh);
+                    self.mat_mul(
+                        &mut attn_output,
+                        0.,
+                        &attn_weights,
+                        &kv_lora,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                    let mut attn_output_r = attn_output_r.map_slice_mut().transpose(&[1, 0]);
+                    let out_absorb = out_absorb.transpose(&[2, 1]);
+
+                    self.mat_mul(
+                        &mut attn_output_r,
+                        0.,
+                        &attn_output,
+                        &out_absorb,
+                        1.,
+                        workspace,
+                        queue_alloc,
+                    )?;
+                }
+                Ops::debug(&attn_output_r, queue);
+                todo!();
+                // let o = attn_output_r;
                 let w = self.weights.attn_o(iblk, queue);
-
-                self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
-                let inplace = unsafe { x.map_slice_static() };
-                self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
+                println!("{:?}", attn_output_r.shape());
+                println!("{:?}", w.shape());
+                // println!("{:?}",out_absorb.shape());
+                todo!();
+                // self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
+                // let inplace = unsafe { x.map_slice_static() };
+                // self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
             }
             let w = self.weights.ffn_norm(iblk, queue);
             self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
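
Note on the new attention path above: it follows the weight-absorbed MLA formulation referenced in the Python comment. In that formulation (symbols follow the tensor names in the diff; the softmax_scale factor is still a TODO in this commit and is not yet applied):

    attn_weights = softmax(q_rope · k_rope^T + (q_nope · q_absorb) · kv_lora^T)
    attn_output  = (attn_weights · kv_lora) · out_absorb^T

where kv_lora is the compressed latent of width dkv_lora, and q_absorb / out_absorb are the two halves split out of the attn_kvb (kv_b_proj) weight. The shape annotations in the surrounding comments are mutually consistent with, for example, nh = 40, dnope = 64, dh = 32 (so dk = 96 and nh * dk = 3840), dv = 64 (nh * dv = 2560), and dkv_lora = 256 (dkv_lora + dh = 288); these concrete sizes are inferred from the annotations, not stated by the commit.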
@@ -594,6 +657,26 @@ where
             queue_alloc,
         )
     }
+    fn softmax<A, QA>(
+        &self,
+        a: &mut Tensor<A>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        A: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.fuesd_softmax.launch(
+            &fuesd_softmax::Args {
+                att_mask: AttnMask::Causal,
+                att_layout: a.layout(),
+                att_base: a.base_mut(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
     fn all_reduce<X, QA>(
         &self,
         x: &mut Tensor<X>,
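
The softmax helper added above drives the fuesd_softmax operator with AttnMask::Causal over the [nh, nt, nt] attn_weights buffer. For reference only, here is a minimal sketch of what a causal, in-place softmax of that shape computes, on a plain f32 slice; this is an illustrative stand-in, not the operators crate's kernel, and it assumes a square causal mask aligned to the current nt tokens:

    // In-place causal softmax over an [nh, nt, nt] buffer:
    // row i of every head attends to columns 0..=i; masked columns become 0.
    fn causal_softmax(att: &mut [f32], nh: usize, nt: usize) {
        assert_eq!(att.len(), nh * nt * nt);
        for h in 0..nh {
            for i in 0..nt {
                let row = &mut att[(h * nt + i) * nt..(h * nt + i + 1) * nt];
                // numerically stable softmax over the unmasked prefix 0..=i
                let max = row[..=i].iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for j in 0..=i {
                    row[j] = (row[j] - max).exp();
                    sum += row[j];
                }
                for j in 0..nt {
                    row[j] = if j <= i { row[j] / sum } else { 0.0 };
                }
            }
        }
    }

The real operator works on the layout reported by a.layout() and goes through the launch/workspace machinery, but the per-row masking and normalization it is expected to perform follow the same idea.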

tensor/src/split.rs

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ impl<T> Splitable for &[T] {
         self
     }
 }
-
 impl<T> Splitable for &mut [T] {
     #[inline]
     fn split(&self) -> Self {
