
Commit c469a04

Add rope
1 parent 247636a commit c469a04

3 files changed: +41 -86 lines changed


Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ itertools = "0.13"
 env_logger = "0.11"
 build-script-cfg = "0.0"

-operators = { git = "https://github.com/onenewcode/operators-rs", rev = "f4a83f7", default-features = false }
+operators = { path = "/home/ztf/operators-rs/operators"}

 search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
 search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/llama/common/src/compute.rs

Lines changed: 1 addition & 0 deletions
@@ -480,6 +480,7 @@ where
                 cos_layout: cos.layout(),
                 cos_base: cos.base(),
                 theta: self.meta.theta,
+                rope_type: rope::RopeType::Rope,
             },
             workspace,
             queue_alloc,
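Note: the new rope_type argument selects between plain RoPE (this llama call site) and the long-RoPE variant used by minicpm3 below. Its definition lives in the local operators-rs checkout and is not part of this diff; the following is only a rough sketch inferred from the two call sites in this commit, with field types guessed.

// Hypothetical shape of rope::RopeType, inferred from the two call sites in this
// commit (llama passes RopeType::Rope, minicpm3 passes RopeType::Long { .. }).
// Field names come from the diff; the pointer and integer types are guesses.
pub enum RopeType {
    // Plain rotary position embedding.
    Rope,
    // Long-RoPE with per-dimension long/short frequency-scaling tables.
    Long {
        long: *const u8,  // base pointer of the long-factor tensor
        short: *const u8, // base pointer of the short-factor tensor
        max_pos: u32,     // maximum position the scaling targets
        origin_pos: u32,  // original trained context length
    },
}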

models/minicpm3/common/src/compute.rs

Lines changed: 39 additions & 85 deletions
@@ -9,12 +9,12 @@ use operators::{
     add::{self, Add},
     all_reduce::{self, AllReduce, ReduceOp},
     attention::{self, Attention},
-    attention_kv_cached::{AttnKVCached},
+    attention_kv_cached::AttnKVCached,
     fuesd_softmax::AttnMask,
     mat_mul::{self, MatMul},
     rearrange::{self, Rearrange},
     rms_norm::{self, RmsNorm},
-    rope::{self, Rope, SinCosTable},
+    rope::{self, Rope, Seq, SinCosTable},
     swiglu::{self, Swiglu},
     ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
@@ -135,7 +135,7 @@ where
 {
     let Args {
         embd: mut x,
-        logits,
+        mut logits,
         requests,
         num_tokens: nt,
         sin_cos,
@@ -151,7 +151,6 @@ where
         dkv_lora,
         dv,
         dt_embd,
-
         ..
     } = self.meta;
     // hard-coded by llama.cpp
@@ -171,12 +170,23 @@ where
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);

-        // perform attention
-        let attn = tensor(&[nt, nh, dv]);
-        let (buf, workspace) = workspace.split_at_mut(*attn.get());
-        let mut attn = attn.map(|_| buf);

         let queue = queue_alloc.queue();
+
+        let sin = sin_cos.clone().index(0, 0);
+        let cos = sin_cos.index(0, 1);
+
+        let pos = Tensor::new(self.dt_pos, &[nt]).map(|_| {
+            Ops::Rope::build_pos(
+                self.dt_pos,
+                nt,
+                requests.iter().map(|req| Seq {
+                    pos: req.pos,
+                    len: req.seq_len,
+                }),
+                queue_alloc,
+            )
+        });
         // scale
         let inplace = unsafe { x.map_slice_static() };
         self.scale(&mut x, &inplace, scale_emb, workspace, queue_alloc)?;
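The build_pos call above packs one position id per token, one contiguous run per request. A minimal sketch of the expected contents, using a hypothetical helper over (pos, seq_len) pairs rather than the actual operators-rs implementation:

// Hypothetical helper illustrating what the position tensor should hold:
// for each request, the ids pos, pos + 1, ..., pos + len - 1, concatenated
// over all requests so the total length equals nt.
fn positions(requests: &[(u32, u32)]) -> Vec<u32> {
    requests
        .iter()
        .flat_map(|&(pos, len)| pos..pos + len)
        .collect()
}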
@@ -232,95 +242,31 @@ where

         split_mut!(kv => k_nope ,v ; [dnope , dv ] @ 2);

-        /// longrope
-        pub fn longrope(
-            embd: &mut [f32],
-            pos: f32,
-            theta: f32,
-            long_factor: &[f32],
-            short_factor: &[f32],
-            max_pos: f32,
-            origin_max_pos: f32,
-        ) {
-            use std::slice::from_raw_parts_mut;
-            // compute scaling_factor
-            let scaling_factor =
-                1.0 + ((max_pos / origin_max_pos).ln() / origin_max_pos.ln()).sqrt();
-            let factor = if pos > origin_max_pos {
-                long_factor
-            } else {
-                short_factor
-            };
-            let dh = embd.len() / 2;
-            let embd =
-                unsafe { from_raw_parts_mut(embd.as_mut_ptr().cast::<[f32; 2]>(), dh) };
-            for (i, pair) in embd.iter_mut().enumerate() {
-                let theta = theta.powf(-(i as f32 / dh as f32));
-                let freq = pos * theta * factor.get(i).unwrap().recip();
-                let (sin, cos) = freq.sin_cos();
-                let (sin, cos) = (sin * scaling_factor, cos * scaling_factor);
-                let [a, b] = *pair;
-                *pair = [a * cos - b * sin, a * sin + b * cos];
-            }
-        }
-        let cast = |t: *const f32| -> &'static [f32] {
-            unsafe { std::slice::from_raw_parts(t, dh / 2) }
-        };
-        let [long_factor, short_factor] = self.weights.factor(queue);
-        let long_factor = cast(long_factor.base().cast());
-        let short_factor = cast(short_factor.base().cast());
-
         // k [1, 3840]
         let k = tensor(&[nt, nh, dk]);
         let (buf, workspace) = workspace.split_at_mut(*k.get());
         let k = k.map(|_| buf);

         split_mut!(k => k_nope_r ,k_rope_r ; [dnope, dh] @ 2);

-        let pos = requests.last().unwrap().pos as f32;
-        let (max_pos, origin_max_pos) = (100f32, 100f32);
-
-        // q embedding
-        (0..nh).for_each(|i| {
-            let tmp_q = unsafe {
-                std::slice::from_raw_parts_mut(
-                    q_rope.base_mut().cast::<f32>().add(i * 32),
-                    32,
-                )
-            };
-            longrope(
-                tmp_q,
-                pos,
-                self.meta.theta,
-                long_factor,
-                short_factor,
-                max_pos,
-                origin_max_pos,
-            );
-        });
-        // k embedding
-
-        let k_rope_1 =
-            unsafe { std::slice::from_raw_parts_mut(k_rope.base_mut().cast::<f32>(), 32) };
-        longrope(
-            k_rope_1,
-            pos,
-            self.meta.theta,
-            long_factor,
-            short_factor,
-            max_pos,
-            origin_max_pos,
-        );
-
-        // broadcast and copy
-        let k_rope = k_rope.tile(1, &[1, dh]).broadcast(1, nh);
+        self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+        let mut k_rope = k_rope.tile(1, &[1, dh]);
+        self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
+        let k_rope = k_rope.broadcast(1, nh);
         self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
         self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;

+        let pos = requests.last().unwrap().pos as f32;
         let mut q = q3.transpose(&[1, 0]);
         let k = k.map_slice().transpose(&[1, 0]);
         let v = v.map_slice_mut().transpose(&[1, 0]);
+        // perform attention
+        let attn = tensor(&[nt, nh, dv]);
+        let (buf, workspace) = workspace.split_at_mut(*attn.get());
+        let mut attn = attn.map(|_| buf);
+
         let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
+        let pos = requests.last().unwrap().pos as f32;
         self.attnention(
             &mut q,
             &k,
@@ -378,8 +324,7 @@ where
         if logits.shape()[0] == 0 {
             return Ok(());
         }
-        Ops::debug(&x, queue);
-        todo!();
+
         // gather the tokens to sample
         // NOTICE: sorting requests by seq len in ascending order before input reduces data-movement cost
         let mut dst = 0;
@@ -404,6 +349,8 @@ where
             self.rms_norm(&mut x, &inplace, &w, workspace, queue_alloc)?
         }
         let w = self.weights.output(queue);
+        Ops::debug(&x, queue);
+        todo!();
         self.mat_mul(&mut logits, 0., &x, &w, 1., workspace, queue_alloc)
     }
 }
@@ -490,6 +437,7 @@ where
     Cos: Deref<Target = [ByteOf<Ops::Hardware>]>,
     QA: QueueAlloc<Hardware = Ops::Hardware>,
 {
+    let [long, short] = self.weights.factor(queue_alloc.queue());
     self.rope.launch(
         &rope::Args {
             t_layout: t.layout(),
@@ -501,6 +449,12 @@
             cos_layout: cos.layout(),
             cos_base: cos.base(),
             theta: self.meta.theta,
+            rope_type: rope::RopeType::Long {
+                long: long.base(),
+                short: short.base(),
+                max_pos: 100,
+                origin_pos: 100,
+            },
         },
         workspace,
         queue_alloc,
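For reference, the per-pair rotation performed by the deleted inline longrope, which the RopeType::Long operator is now expected to cover, can be written in safe Rust as follows. This is a sketch mirroring the removed code, not the operator implementation; both max_pos and origin_max_pos were hard-coded to 100 there, which makes the scaling term 1.0.

// Sketch of the long-RoPE rotation the removed code applied to each
// (even, odd) element pair of a query/key head.
fn longrope(
    embd: &mut [f32],
    pos: f32,
    theta: f32,
    long_factor: &[f32],
    short_factor: &[f32],
    max_pos: f32,
    origin_max_pos: f32,
) {
    // Attention-scaling term; equals 1.0 when max_pos == origin_max_pos.
    let scaling = 1.0 + ((max_pos / origin_max_pos).ln() / origin_max_pos.ln()).sqrt();
    // Long factors once the position exceeds the original context length, short otherwise.
    let factor = if pos > origin_max_pos { long_factor } else { short_factor };
    let dh = embd.len() / 2;
    for i in 0..dh {
        let theta_i = theta.powf(-(i as f32) / dh as f32);
        let freq = pos * theta_i / factor[i];
        let (sin, cos) = freq.sin_cos();
        let (sin, cos) = (sin * scaling, cos * scaling);
        let (a, b) = (embd[2 * i], embd[2 * i + 1]);
        embd[2 * i] = a * cos - b * sin;
        embd[2 * i + 1] = a * sin + b * cos;
    }
}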
