@@ -9,12 +9,12 @@ use operators::{
99 add:: { self , Add } ,
1010 all_reduce:: { self , AllReduce , ReduceOp } ,
1111 attention:: { self , Attention } ,
12- attention_kv_cached:: { AttnKVCached } ,
12+ attention_kv_cached:: AttnKVCached ,
1313 fuesd_softmax:: AttnMask ,
1414 mat_mul:: { self , MatMul } ,
1515 rearrange:: { self , Rearrange } ,
1616 rms_norm:: { self , RmsNorm } ,
17- rope:: { self , Rope , SinCosTable } ,
17+ rope:: { self , Rope , Seq , SinCosTable } ,
1818 swiglu:: { self , Swiglu } ,
1919 ByteOf , Hardware , LaunchError , Operator , QueueAlloc , QueueOf , TopoNode , Workspace ,
2020} ;
@@ -135,7 +135,7 @@ where
135135 {
136136 let Args {
137137 embd : mut x,
138- logits,
138+ mut logits,
139139 requests,
140140 num_tokens : nt,
141141 sin_cos,
@@ -151,7 +151,6 @@ where
151151 dkv_lora,
152152 dv,
153153 dt_embd,
154-
155154 ..
156155 } = self . meta ;
157156 // llama.cpp 定义死
@@ -171,12 +170,23 @@ where
171170 let ( buf, workspace) = workspace. split_at_mut ( * x1. get ( ) ) ;
172171 let mut x1 = x1. map ( |_| buf) ;
173172
174- // Run attention
175- let attn = tensor ( & [ nt, nh, dv] ) ;
176- let ( buf, workspace) = workspace. split_at_mut ( * attn. get ( ) ) ;
177- let mut attn = attn. map ( |_| buf) ;
178173
179174 let queue = queue_alloc. queue ( ) ;
175+
176+ let sin = sin_cos. clone ( ) . index ( 0 , 0 ) ;
177+ let cos = sin_cos. index ( 0 , 1 ) ;
178+
179+ let pos = Tensor :: new ( self . dt_pos , & [ nt] ) . map ( |_| {
180+ Ops :: Rope :: build_pos (
181+ self . dt_pos ,
182+ nt,
183+ requests. iter ( ) . map ( |req| Seq {
184+ pos : req. pos ,
185+ len : req. seq_len ,
186+ } ) ,
187+ queue_alloc,
188+ )
189+ } ) ;
180190 // 缩放
181191 let inplace = unsafe { x. map_slice_static ( ) } ;
182192 self . scale ( & mut x, & inplace, scale_emb, workspace, queue_alloc) ?;
@@ -232,95 +242,31 @@ where
232242
233243 split_mut ! ( kv => k_nope , v ; [ dnope , dv ] @ 2 ) ;
234244
235- /// longrope: rotates `embd` in place, treating it as consecutive `[f32; 2]` pairs; the per-pair angle is `pos * theta^(-i/dh) / factor[i]`.
236- pub fn longrope (
237- embd : & mut [ f32 ] ,
238- pos : f32 ,
239- theta : f32 ,
240- long_factor : & [ f32 ] ,
241- short_factor : & [ f32 ] ,
242- max_pos : f32 ,
243- origin_max_pos : f32 ,
244- ) {
245- use std:: slice:: from_raw_parts_mut;
246- // Compute scaling_factor = 1 + sqrt(ln(max_pos / origin_max_pos) / ln(origin_max_pos)); it equals 1.0 when max_pos == origin_max_pos.
247- let scaling_factor =
248- 1.0 + ( ( max_pos / origin_max_pos) . ln ( ) / origin_max_pos. ln ( ) ) . sqrt ( ) ;
249- let factor = if pos > origin_max_pos { // positions beyond origin_max_pos use long_factor, all others short_factor
250- long_factor
251- } else {
252- short_factor
253- } ;
254- let dh = embd. len ( ) / 2 ; // number of [f32; 2] pairs in embd
255- let embd =
256- unsafe { from_raw_parts_mut ( embd. as_mut_ptr ( ) . cast :: < [ f32 ; 2 ] > ( ) , dh) } ; // reinterpret the f32 slice as dh pairs; an odd trailing element is not covered
257- for ( i, pair) in embd. iter_mut ( ) . enumerate ( ) {
258- let theta = theta. powf ( -( i as f32 / dh as f32 ) ) ; // per-pair base frequency, shadows the `theta` parameter
259- let freq = pos * theta * factor. get ( i) . unwrap ( ) . recip ( ) ; // divides by factor[i]; panics if factor has fewer than i + 1 entries
260- let ( sin, cos) = freq. sin_cos ( ) ;
261- let ( sin, cos) = ( sin * scaling_factor, cos * scaling_factor) ; // both sin and cos are multiplied by scaling_factor
262- let [ a, b] = * pair;
263- * pair = [ a * cos - b * sin, a * sin + b * cos] ; // 2-D rotation of the pair using the scaled sin/cos
264- }
265- }
266- let cast = |t : * const f32 | -> & ' static [ f32 ] {
267- unsafe { std:: slice:: from_raw_parts ( t, dh / 2 ) }
268- } ;
269- let [ long_factor, short_factor] = self . weights . factor ( queue) ;
270- let long_factor = cast ( long_factor. base ( ) . cast ( ) ) ;
271- let short_factor = cast ( short_factor. base ( ) . cast ( ) ) ;
272-
273245 // k [1, 3840]
274246 let k = tensor ( & [ nt, nh, dk] ) ;
275247 let ( buf, workspace) = workspace. split_at_mut ( * k. get ( ) ) ;
276248 let k = k. map ( |_| buf) ;
277249
278250 split_mut ! ( k => k_nope_r , k_rope_r ; [ dnope, dh] @ 2 ) ;
279251
280- let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
281- let ( max_pos, origin_max_pos) = ( 100f32 , 100f32 ) ;
282-
283- // q 嵌入
284- ( 0 ..nh) . for_each ( |i| {
285- let tmp_q = unsafe {
286- std:: slice:: from_raw_parts_mut (
287- q_rope. base_mut ( ) . cast :: < f32 > ( ) . add ( i * 32 ) ,
288- 32 ,
289- )
290- } ;
291- longrope (
292- tmp_q,
293- pos,
294- self . meta . theta ,
295- long_factor,
296- short_factor,
297- max_pos,
298- origin_max_pos,
299- ) ;
300- } ) ;
301- // k 嵌入
302-
303- let k_rope_1 =
304- unsafe { std:: slice:: from_raw_parts_mut ( k_rope. base_mut ( ) . cast :: < f32 > ( ) , 32 ) } ;
305- longrope (
306- k_rope_1,
307- pos,
308- self . meta . theta ,
309- long_factor,
310- short_factor,
311- max_pos,
312- origin_max_pos,
313- ) ;
314-
315- // Broadcast and copy
316- let k_rope = k_rope. tile ( 1 , & [ 1 , dh] ) . broadcast ( 1 , nh) ;
252+ self . rope ( & mut q_rope, & pos, & sin, & cos, workspace, queue_alloc) ?;
253+ let mut k_rope = k_rope. tile ( 1 , & [ 1 , dh] ) ;
254+ self . rope ( & mut k_rope, & pos, & sin, & cos, workspace, queue_alloc) ?;
255+ let k_rope = k_rope. broadcast ( 1 , nh) ;
317256 self . rearrange ( & mut k_rope_r, & k_rope, workspace, queue_alloc) ?;
318257 self . rearrange ( & mut k_nope_r, & k_nope, workspace, queue_alloc) ?;
319258
259+ let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
320260 let mut q = q3. transpose ( & [ 1 , 0 ] ) ;
321261 let k = k. map_slice ( ) . transpose ( & [ 1 , 0 ] ) ;
322262 let v = v. map_slice_mut ( ) . transpose ( & [ 1 , 0 ] ) ;
263+ // Run attention
264+ let attn = tensor ( & [ nt, nh, dv] ) ;
265+ let ( buf, workspace) = workspace. split_at_mut ( * attn. get ( ) ) ;
266+ let mut attn = attn. map ( |_| buf) ;
267+
323268 let mut attn = unsafe { attn. map_slice_mut ( ) . transpose ( & [ 1 , 0 ] ) } ;
269+ let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
324270 self . attnention (
325271 & mut q,
326272 & k,
@@ -378,8 +324,7 @@ where
378324 if logits. shape ( ) [ 0 ] == 0 {
379325 return Ok ( ( ) ) ;
380326 }
381- Ops :: debug ( & x, queue) ;
382- todo ! ( ) ;
327+
383328 // 集中要采样的 token
384329 // NOTICE: 输入之前将请求按 seq len 升序排列可降低移动开销
385330 let mut dst = 0 ;
@@ -404,6 +349,8 @@ where
404349 self . rms_norm ( & mut x, & inplace, & w, workspace, queue_alloc) ?
405350 }
406351 let w = self . weights . output ( queue) ;
352+ Ops :: debug ( & x, queue) ;
353+ todo ! ( ) ;
407354 self . mat_mul ( & mut logits, 0. , & x, & w, 1. , workspace, queue_alloc)
408355 }
409356}
@@ -490,6 +437,7 @@ where
490437 Cos : Deref < Target = [ ByteOf < Ops :: Hardware > ] > ,
491438 QA : QueueAlloc < Hardware = Ops :: Hardware > ,
492439 {
440+ let [ long, short] = self . weights . factor ( queue_alloc. queue ( ) ) ;
493441 self . rope . launch (
494442 & rope:: Args {
495443 t_layout : t. layout ( ) ,
@@ -501,6 +449,12 @@ where
501449 cos_layout : cos. layout ( ) ,
502450 cos_base : cos. base ( ) ,
503451 theta : self . meta . theta ,
452+ rope_type : rope:: RopeType :: Long {
453+ long : long. base ( ) ,
454+ short : short. base ( ) ,
455+ max_pos : 100 ,
456+ origin_pos : 100 ,
457+ } ,
504458 } ,
505459 workspace,
506460 queue_alloc,
0 commit comments