Skip to content

Commit 166d2ea

Browse files
committed
简单优化结构
1 parent 88e0b82 commit 166d2ea

File tree

1 file changed

+58
-68
lines changed

1 file changed

+58
-68
lines changed

models/minicpm3/common/src/compute.rs

Lines changed: 58 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -170,17 +170,14 @@ where
170170
let dnope = dk - dh;
171171
let tensor = |shape: &[usize]| Tensor::new(dt_embd, shape);
172172
let x1 = tensor(x.shape());
173-
let q = tensor(&[nt, dq_lora]);
174-
let kv_pe = tensor(&[nt, dh + dkv_lora]);
173+
175174
let gate_up = tensor(&[nt, di * 2]);
175+
// 空间 x+x1+q(应该可以删除)+q3+kv_pe+attn
176176
let workspace_size = *x1.get() * 3 + *gate_up.get();
177177
let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
178178
let (buf, workspace) = workspace.split_at_mut(*x1.get());
179179
let mut x1 = x1.map(|_| buf);
180-
let (buf, workspace) = workspace.split_at_mut(*q.get());
181-
let mut q = q.map(|_| buf);
182-
let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
183-
let mut kv_pe = kv_pe.map(|_| buf);
180+
184181
// 经行 attention
185182
let attn = tensor(&[nt, nh, dv]);
186183
let (buf, workspace) = workspace.split_at_mut(*attn.get());
@@ -191,52 +188,53 @@ where
191188
// norm
192189
let w = self.weights.attn_norm(iblk, queue);
193190
self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
194-
// if iblk==1{
195-
// Ops::debug(&x1, queue);
196-
// todo!();
197-
// }
198191
drop(w);
192+
let q = tensor(&[nt, dq_lora]);
193+
let (buf, workspace) = workspace.split_at_mut(*q.get());
194+
let mut q = q.map(|_| buf);
195+
let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
196+
self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
197+
198+
let inplace = unsafe { q.map_slice_static() };
199+
let w = self.weights.attn_qa_norm(iblk, queue);
200+
self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
199201
{
200-
let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
201-
self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
202-
203-
let inplace = unsafe { q.map_slice_static() };
204-
let w = self.weights.attn_qa_norm(iblk, queue);
205-
self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
206-
207-
let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
202+
// q [1, 768] q1 [1, 3840] kv_pe [1,288] kv [1, 5120] k [1, 3840] attn [1, 2560]
208203
let q1 = tensor(&[nt, nh * dk]);
209204
let (buf, workspace) = workspace.split_at_mut(*q1.get());
210205
let mut q1 = q1.map(|_| buf);
206+
let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
211207
self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
212-
let q3 = q1.tile(1, &[nh, dk]);
213-
let parts = [dnope, dh];
214-
let mut parts = q3.split(2, &parts);
215-
let _ = parts.next().unwrap();
216-
let mut q_rope_0 = parts.next().unwrap();
217-
assert!(parts.next().is_none());
218-
drop(parts);
208+
drop(q);
209+
// q3 是计算 attn 需要用到的数据,但是我们仍然需要对 q3 的的部分进行嵌入操作
210+
let mut q3 = q1.tile(1, &[nh, dk]);
211+
let q2 = unsafe { q3.map_slice_static_mut() };
212+
split_mut!(q2=>_q, q_rope;[dnope, dh]@ 2);
213+
214+
// kv_pe [1,288]
215+
let kv_pe = tensor(&[nt, dkv_lora + dh]);
216+
let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
217+
let mut kv_pe = kv_pe.map(|_| buf);
218+
219219
let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
220220
self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
221221

222-
split_mut!(kv_pe => kv_lora_0, k_rope_0; [dkv_lora, dh] @ 1);
222+
split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
223223

224-
// kv_pe
225-
let kv_lora_1 = tensor(&[nt, dkv_lora]);
226-
let (buf, workspace) = workspace.split_at_mut(*kv_lora_1.get());
227-
let mut kv_lora_1 = kv_lora_1.map(|_| buf);
224+
let inplace = unsafe { kv_lora.map_slice_static() };
228225
let w = self.weights.attn_kva_norm(iblk, queue);
229-
self.rms_norm(&mut kv_lora_1, &kv_lora_0, &w, workspace, queue_alloc)?;
230-
231-
let kv_0 = tensor(&[nt, nh * (dnope + dv)]);
232-
let (buf, workspace) = workspace.split_at_mut(*kv_0.get());
233-
let mut kv_0 = kv_0.map(|_| buf);
226+
self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
227+
// kv X[1, 5120]
228+
let kv = tensor(&[nt, nh * (dnope + dv)]);
229+
let (buf, workspace) = workspace.split_at_mut(*kv.get());
230+
let mut kv = kv.map(|_| buf);
234231
let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
235-
self.mat_mul(&mut kv_0, 0., &kv_lora_1, &w, 1., workspace, queue_alloc)?;
236232

237-
let kv_1 = kv_0.tile(1, &[nh, dnope + dv]);
233+
self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;
238234

239-
split_mut!(kv_1 => k_nope ,v ; [dnope , dv ] @ 2);
235+
let kv = kv.tile(1, &[nh, dnope + dv]);
236+
237+
split_mut!(kv => k_nope ,v ; [dnope , dv ] @ 2);
240238

241239
/// longrope
242240
pub fn longrope(
@@ -276,23 +274,21 @@ where
276274
let long_factor = cast(long_factor.base().cast());
277275
let short_factor = cast(short_factor.base().cast());
278276

279-
// k dk
277+
// k [1, 3840]
280278
let k = tensor(&[nt, nh, dk]);
281279
let (buf, workspace) = workspace.split_at_mut(*k.get());
282280
let mut k = k.map(|_| buf);
283-
let parts = [dnope, dh];
284-
let mut parts = k.split(2, &parts);
285-
let mut k_nope_r = parts.next().unwrap();
286-
let mut k_rope_r = parts.next().unwrap();
287-
assert!(parts.next().is_none());
281+
282+
split_mut!(k => k_nope_r ,k_rope_r ; [dnope, dh] @ 2);
283+
288284
let pos = requests.last().unwrap().pos as f32;
289285
let (max_pos, origin_max_pos) = (100f32, 100f32);
290286

291287
// q 嵌入
292288
(0..nh).for_each(|i| {
293289
let mut tmp_q = unsafe {
294290
std::slice::from_raw_parts_mut(
295-
q_rope_0.base_mut().cast::<f32>().offset((i * 32) as isize),
291+
q_rope.base_mut().cast::<f32>().offset((i * 32) as isize),
296292
32,
297293
)
298294
};
@@ -306,30 +302,23 @@ where
306302
origin_max_pos,
307303
);
308304
});
305+
// k 嵌入
306+
307+
let mut k_rope_1 =
308+
unsafe { std::slice::from_raw_parts_mut(k_rope.base_mut().cast::<f32>(), 32) };
309+
longrope(
310+
&mut k_rope_1,
311+
pos,
312+
self.meta.theta,
313+
long_factor,
314+
short_factor,
315+
max_pos,
316+
origin_max_pos,
317+
);
309318

310-
// println!("q {:?}",k_rope_0.shape());
311-
// todo!();
312-
// k 嵌入
313-
314-
{
315-
let mut k_rope_1 = unsafe {
316-
std::slice::from_raw_parts_mut(k_rope_0.base_mut().cast::<f32>(), 32)
317-
};
318-
longrope(
319-
&mut k_rope_1,
320-
pos,
321-
self.meta.theta,
322-
long_factor,
323-
short_factor,
324-
max_pos,
325-
origin_max_pos,
326-
);
327-
}
328-
329-
// TODO 未确认
330319
// 经行广播和拷贝
331-
let k_rope_2 = k_rope_0.tile(1, &[1, dh]).broadcast(1, nh);
332-
self.rearrange(&mut k_rope_r, &k_rope_2, workspace, queue_alloc)?;
320+
let k_rope = k_rope.tile(1, &[1, dh]).broadcast(1, nh);
321+
self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
333322
self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
334323

335324
let mut q = q3.transpose(&[1, 0]);
@@ -393,7 +382,8 @@ where
393382
if logits.shape()[0] == 0 {
394383
return Ok(());
395384
}
396-
385+
Ops::debug(&x, queue);
386+
todo!();
397387
// 集中要采样的 token
398388
// NOTICE: 输入之前将请求按 seq len 升序排列可降低移动开销
399389
let mut dst = 0;

0 commit comments

Comments (0)