@@ -170,17 +170,14 @@ where
170170 let dnope = dk - dh;
171171 let tensor = |shape : & [ usize ] | Tensor :: new ( dt_embd, shape) ;
172172 let x1 = tensor ( x. shape ( ) ) ;
173- let q = tensor ( & [ nt, dq_lora] ) ;
174- let kv_pe = tensor ( & [ nt, dh + dkv_lora] ) ;
173+
175174 let gate_up = tensor ( & [ nt, di * 2 ] ) ;
175+ // 空间 x+x1+q(应该可以删除)+q3+kv_pe+attn
176176 let workspace_size = * x1. get ( ) * 3 + * gate_up. get ( ) ;
177177 let mut workspace = Workspace :: new ( queue_alloc, workspace, workspace_size) ;
178178 let ( buf, workspace) = workspace. split_at_mut ( * x1. get ( ) ) ;
179179 let mut x1 = x1. map ( |_| buf) ;
180- let ( buf, workspace) = workspace. split_at_mut ( * q. get ( ) ) ;
181- let mut q = q. map ( |_| buf) ;
182- let ( buf, workspace) = workspace. split_at_mut ( * kv_pe. get ( ) ) ;
183- let mut kv_pe = kv_pe. map ( |_| buf) ;
180+
184181 // 经行 attention
185182 let attn = tensor ( & [ nt, nh, dv] ) ;
186183 let ( buf, workspace) = workspace. split_at_mut ( * attn. get ( ) ) ;
@@ -191,52 +188,53 @@ where
191188 // norm
192189 let w = self . weights . attn_norm ( iblk, queue) ;
193190 self . rms_norm ( & mut x1, & x, & w, workspace, queue_alloc) ?;
194- // if iblk==1{
195- // Ops::debug(&x1, queue);
196- // todo!();
197- // }
198191 drop ( w) ;
192+ let q = tensor ( & [ nt, dq_lora] ) ;
193+ let ( buf, workspace) = workspace. split_at_mut ( * q. get ( ) ) ;
194+ let mut q = q. map ( |_| buf) ;
195+ let w = self . weights . attn_qa ( iblk, queue) . transpose ( & [ 1 , 0 ] ) ;
196+ self . mat_mul ( & mut q, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
197+
198+ let inplace = unsafe { q. map_slice_static ( ) } ;
199+ let w = self . weights . attn_qa_norm ( iblk, queue) ;
200+ self . rms_norm ( & mut q, & inplace, & w, workspace, queue_alloc) ?;
199201 {
200- let w = self . weights . attn_qa ( iblk, queue) . transpose ( & [ 1 , 0 ] ) ;
201- self . mat_mul ( & mut q, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
202-
203- let inplace = unsafe { q. map_slice_static ( ) } ;
204- let w = self . weights . attn_qa_norm ( iblk, queue) ;
205- self . rms_norm ( & mut q, & inplace, & w, workspace, queue_alloc) ?;
206-
207- let w = self . weights . attn_qb ( iblk, queue) . transpose ( & [ 1 , 0 ] ) ;
202+ // q [1, 768] q1 [1, 3840] kv_pe [1,288] kv [1, 5120] k [1, 3840] attn [1, 2560]
208203 let q1 = tensor ( & [ nt, nh * dk] ) ;
209204 let ( buf, workspace) = workspace. split_at_mut ( * q1. get ( ) ) ;
210205 let mut q1 = q1. map ( |_| buf) ;
206+ let w = self . weights . attn_qb ( iblk, queue) . transpose ( & [ 1 , 0 ] ) ;
211207 self . mat_mul ( & mut q1, 0. , & q, & w, 1. , workspace, queue_alloc) ?;
212- let q3 = q1. tile ( 1 , & [ nh, dk] ) ;
213- let parts = [ dnope, dh] ;
214- let mut parts = q3. split ( 2 , & parts) ;
215- let _ = parts. next ( ) . unwrap ( ) ;
216- let mut q_rope_0 = parts. next ( ) . unwrap ( ) ;
217- assert ! ( parts. next( ) . is_none( ) ) ;
218- drop ( parts) ;
208+ drop ( q) ;
209+ // q3 是计算 attn 需要用到的数据,但是我们仍然需要对 q3 的的部分进行嵌入操作
210+ let mut q3 = q1. tile ( 1 , & [ nh, dk] ) ;
211+ let q2 = unsafe { q3. map_slice_static_mut ( ) } ;
212+ split_mut ! ( q2=>_q, q_rope; [ dnope, dh] @ 2 ) ;
213+
214+ // kv_pe [1,288]
215+ let kv_pe = tensor ( & [ nt, dkv_lora + dh] ) ;
216+ let ( buf, workspace) = workspace. split_at_mut ( * kv_pe. get ( ) ) ;
217+ let mut kv_pe = kv_pe. map ( |_| buf) ;
218+
219219 let w = self . weights . attn_kva ( iblk, queue) . transpose ( & [ 1 , 0 ] ) ;
220220 self . mat_mul ( & mut kv_pe, 0. , & x1, & w, 1. , workspace, queue_alloc) ?;
221221
222- split_mut ! ( kv_pe => kv_lora_0 , k_rope_0 ; [ dkv_lora, dh] @ 1 ) ;
222+ split_mut ! ( kv_pe => kv_lora , k_rope ; [ dkv_lora, dh] @ 1 ) ;
223223
224- // kv_pe
225- let kv_lora_1 = tensor ( & [ nt, dkv_lora] ) ;
226- let ( buf, workspace) = workspace. split_at_mut ( * kv_lora_1. get ( ) ) ;
227- let mut kv_lora_1 = kv_lora_1. map ( |_| buf) ;
224+ let inplace = unsafe { kv_lora. map_slice_static ( ) } ;
228225 let w = self . weights . attn_kva_norm ( iblk, queue) ;
229- self . rms_norm ( & mut kv_lora_1 , & kv_lora_0 , & w, workspace, queue_alloc) ?;
230-
231- let kv_0 = tensor ( & [ nt, nh * ( dnope + dv) ] ) ;
232- let ( buf, workspace) = workspace. split_at_mut ( * kv_0 . get ( ) ) ;
233- let mut kv_0 = kv_0 . map ( |_| buf) ;
226+ self . rms_norm ( & mut kv_lora , & inplace , & w, workspace, queue_alloc) ?;
227+ // kv X[1, 5120]
228+ let kv = tensor ( & [ nt, nh * ( dnope + dv) ] ) ;
229+ let ( buf, workspace) = workspace. split_at_mut ( * kv . get ( ) ) ;
230+ let mut kv = kv . map ( |_| buf) ;
234231 let w = self . weights . attn_kvb ( iblk, queue) . transpose ( & [ 1 , 0 ] ) ;
235- self . mat_mul ( & mut kv_0, 0. , & kv_lora_1, & w, 1. , workspace, queue_alloc) ?;
236232
237- let kv_1 = kv_0 . tile ( 1 , & [ nh , dnope + dv ] ) ;
233+ self . mat_mul ( & mut kv , 0. , & kv_lora , & w , 1. , workspace , queue_alloc ) ? ;
238234
239- split_mut ! ( kv_1 => k_nope , v ; [ dnope , dv ] @ 2 ) ;
235+ let kv = kv. tile ( 1 , & [ nh, dnope + dv] ) ;
236+
237+ split_mut ! ( kv => k_nope , v ; [ dnope , dv ] @ 2 ) ;
240238
241239 /// longrope
242240 pub fn longrope (
@@ -276,23 +274,21 @@ where
276274 let long_factor = cast ( long_factor. base ( ) . cast ( ) ) ;
277275 let short_factor = cast ( short_factor. base ( ) . cast ( ) ) ;
278276
279- // k dk
277+ // k [1, 3840]
280278 let k = tensor ( & [ nt, nh, dk] ) ;
281279 let ( buf, workspace) = workspace. split_at_mut ( * k. get ( ) ) ;
282280 let mut k = k. map ( |_| buf) ;
283- let parts = [ dnope, dh] ;
284- let mut parts = k. split ( 2 , & parts) ;
285- let mut k_nope_r = parts. next ( ) . unwrap ( ) ;
286- let mut k_rope_r = parts. next ( ) . unwrap ( ) ;
287- assert ! ( parts. next( ) . is_none( ) ) ;
281+
282+ split_mut ! ( k => k_nope_r , k_rope_r ; [ dnope, dh] @ 2 ) ;
283+
288284 let pos = requests. last ( ) . unwrap ( ) . pos as f32 ;
289285 let ( max_pos, origin_max_pos) = ( 100f32 , 100f32 ) ;
290286
291287 // q 嵌入
292288 ( 0 ..nh) . for_each ( |i| {
293289 let mut tmp_q = unsafe {
294290 std:: slice:: from_raw_parts_mut (
295- q_rope_0 . base_mut ( ) . cast :: < f32 > ( ) . offset ( ( i * 32 ) as isize ) ,
291+ q_rope . base_mut ( ) . cast :: < f32 > ( ) . offset ( ( i * 32 ) as isize ) ,
296292 32 ,
297293 )
298294 } ;
@@ -306,30 +302,23 @@ where
306302 origin_max_pos,
307303 ) ;
308304 } ) ;
305+ // k 嵌入
306+
307+ let mut k_rope_1 =
308+ unsafe { std:: slice:: from_raw_parts_mut ( k_rope. base_mut ( ) . cast :: < f32 > ( ) , 32 ) } ;
309+ longrope (
310+ & mut k_rope_1,
311+ pos,
312+ self . meta . theta ,
313+ long_factor,
314+ short_factor,
315+ max_pos,
316+ origin_max_pos,
317+ ) ;
309318
310- // println!("q {:?}",k_rope_0.shape());
311- // todo!();
312- // k 嵌入
313-
314- {
315- let mut k_rope_1 = unsafe {
316- std:: slice:: from_raw_parts_mut ( k_rope_0. base_mut ( ) . cast :: < f32 > ( ) , 32 )
317- } ;
318- longrope (
319- & mut k_rope_1,
320- pos,
321- self . meta . theta ,
322- long_factor,
323- short_factor,
324- max_pos,
325- origin_max_pos,
326- ) ;
327- }
328-
329- // TODO 未确认
330319 // 经行广播和拷贝
331- let k_rope_2 = k_rope_0 . tile ( 1 , & [ 1 , dh] ) . broadcast ( 1 , nh) ;
332- self . rearrange ( & mut k_rope_r, & k_rope_2 , workspace, queue_alloc) ?;
320+ let k_rope = k_rope . tile ( 1 , & [ 1 , dh] ) . broadcast ( 1 , nh) ;
321+ self . rearrange ( & mut k_rope_r, & k_rope , workspace, queue_alloc) ?;
333322 self . rearrange ( & mut k_nope_r, & k_nope, workspace, queue_alloc) ?;
334323
335324 let mut q = q3. transpose ( & [ 1 , 0 ] ) ;
@@ -393,7 +382,8 @@ where
393382 if logits. shape ( ) [ 0 ] == 0 {
394383 return Ok ( ( ) ) ;
395384 }
396-
385+ Ops :: debug ( & x, queue) ;
386+ todo ! ( ) ;
397387 // 集中要采样的 token
398388 // NOTICE: 输入之前将请求按 seq len 升序排列可降低移动开销
399389 let mut dst = 0 ;
0 commit comments