@@ -72,6 +72,8 @@ pub trait WeightLoader {
/// Returns the two memory blobs associated with the resampler attention K projection.
///
/// The returned array borrows from both `self` and `queue` for the same lifetime.
/// NOTE(review): presumably ordered `[weight, bias]`, matching how the sibling
/// `resampler_attn_o` result is destructured in this file; confirm against the loader.
7272 fn resampler_attn_k < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Memory < ' a > ; 2 ] ;
/// Returns the two memory blobs associated with the resampler attention V projection.
///
/// The returned array borrows from both `self` and `queue` for the same lifetime.
/// NOTE(review): presumably ordered `[weight, bias]`, matching how the sibling
/// `resampler_attn_o` result is destructured in this file; confirm against the loader.
7373 fn resampler_attn_v < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Memory < ' a > ; 2 ] ;
/// Returns the two memory blobs for the resampler attention output projection.
///
/// The caller in this file destructures the result as `[w, b]` and passes `w` as the
/// matrix operand and `b` as the optional bias of the output `mat_mul`.
/// The returned array borrows from both `self` and `queue` for the same lifetime.
7474 fn resampler_attn_o < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> [ Self :: Memory < ' a > ; 2 ] ;
/// Returns the two memory blobs for the layer norm applied after the attention
/// output projection in the resampler.
///
/// The caller in this file destructures the result as `[w, b]`, maps each onto a
/// `[d]`-shaped tensor of the norm data type, and feeds the pair to `layer_norm`
/// over the output-projection result in place.
/// NOTE(review): `w` and `b` are presumably the norm scale and bias; confirm.
75+ fn resampler_ln_post < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > )
76+ -> [ Self :: Memory < ' a > ; 2 ] ;
/// Returns the single memory blob for the final resampler projection.
///
/// The caller in this file maps it onto a tensor with the attention weight layout and
/// uses it as the matrix operand of the last `mat_mul`, with no bias.
/// The returned value borrows from both `self` and `queue` for the same lifetime.
7577 fn resampler_proj < ' a > ( & ' a self , queue : & ' a QueueOf < Self :: Hardware > ) -> Self :: Memory < ' a > ;
7678}
7779
@@ -268,7 +270,7 @@ where
268270
269271 let weights = & self . weights . weights ;
270272 let q0 = Tensor :: new ( dt, & [ dq, d] ) . map ( |_| weights. resampler_q ( queue) ) ;
271- let ln_qkv = Tensor :: new ( dt_norm, & [ d] ) ;
273+ let ln = Tensor :: new ( dt_norm, & [ d] ) ;
272274
273275 let q = Tensor :: new ( dt, q0. shape ( ) ) ;
274276 let kv = Tensor :: new ( dt, & [ np, d] ) ;
@@ -285,11 +287,11 @@ where
285287 self . mat_mul ( & mut v, & x, ( w, None ) , workspace, queue_alloc) ?;
286288
287289 let [ w, b] = weights. resampler_ln_q ( queue) ;
288- let ln_q = [ ln_qkv . clone ( ) . map ( |_| w) , ln_qkv . clone ( ) . map ( |_| b) ] ;
290+ let ln_q = [ ln . clone ( ) . map ( |_| w) , ln . clone ( ) . map ( |_| b) ] ;
289291 self . layer_norm ( & mut q, & q0, ln_q, workspace, queue_alloc) ?;
290292
291293 let [ w, b] = weights. resampler_ln_kv ( queue) ;
292- let ln_v = [ ln_qkv . clone ( ) . map ( |_| w) , ln_qkv . clone ( ) . map ( |_| b) ] ;
294+ let ln_v = [ ln . clone ( ) . map ( |_| w) , ln . clone ( ) . map ( |_| b) ] ;
293295 let inplace = unsafe { v. map_slice_static ( ) } ;
294296 self . layer_norm ( & mut v, & inplace, ln_v, workspace, queue_alloc) ?;
295297
@@ -315,34 +317,34 @@ where
315317 let [ w, b] = weights. resampler_attn_o ( queue) ;
316318 let attn_o = ( attn_w. clone ( ) . map ( |_| w) , Some ( attn_b. clone ( ) . map ( |_| b) ) ) ;
317319
318- let q_ = Tensor :: new ( dt, & [ batch, dq, d] ) ;
319- let k_ = Tensor :: new ( dt, & [ np, d] ) ;
320- let v_ = Tensor :: new ( dt, & [ np, d] ) ;
321- let o_ = Tensor :: new ( dt, & [ batch * dq, d] ) ;
320+ let qo = Tensor :: new ( dt, & [ batch * dq, d] ) ;
322321
323- let ( buf, workspace) = workspace. split_at_mut ( * q_ . get ( ) ) ;
324- let mut q_ = q_ . map ( |_| buf) ;
322+ let ( buf, workspace) = workspace. split_at_mut ( * qo . get ( ) ) ;
323+ let mut q_ = qo . clone ( ) . map ( |_| buf) ;
325324 {
326- let mut q_ = q_. map_slice_mut ( ) . index ( 0 , 0 ) ;
327- self . mat_mul ( & mut q_, & q, attn_q, workspace, queue_alloc) ?
328- }
329- if batch > 1 {
330- split ! ( q_ => q0, q1; [ 1 , batch - 1 ] @ 0 ) ;
331- let q0 = q0. broadcast ( 0 , batch - 1 ) ;
332- let mut q1 = q1;
333- self . rearrange ( & mut q1, & q0, workspace, queue_alloc) ?
325+ let mut q_ = q_. map_slice_mut ( ) . tile ( 0 , & [ batch, dq] ) ;
326+ {
327+ let mut q_ = q_. map_slice_mut ( ) . index ( 0 , 0 ) ;
328+ self . mat_mul ( & mut q_, & q, attn_q, workspace, queue_alloc) ?
329+ }
330+ if batch > 1 {
331+ split ! ( q_ => q0, q1; [ 1 , batch - 1 ] @ 0 ) ;
332+ let q0 = q0. broadcast ( 0 , batch - 1 ) ;
333+ let mut q1 = q1;
334+ self . rearrange ( & mut q1, & q0, workspace, queue_alloc) ?
335+ }
334336 }
335- let mut q_ = q_. merge ( 0 ..2 ) . unwrap ( ) ;
337+ {
338+ let kv = Tensor :: new ( dt, & [ np, d] ) ;
336339
337- let ( buf, workspace) = workspace. split_at_mut ( * k_ . get ( ) ) ;
338- let mut k_ = k_ . map ( |_| buf) ;
339- self . mat_mul ( & mut k_, & k, attn_k, workspace, queue_alloc) ?;
340+ let ( buf, workspace) = workspace. split_at_mut ( * kv . get ( ) ) ;
341+ let mut k_ = kv . clone ( ) . map ( |_| buf) ;
342+ self . mat_mul ( & mut k_, & k, attn_k, workspace, queue_alloc) ?;
340343
341- let ( buf, workspace) = workspace. split_at_mut ( * v_ . get ( ) ) ;
342- let mut v_ = v_ . map ( |_| buf) ;
343- self . mat_mul ( & mut v_, & v, attn_v, workspace, queue_alloc) ?;
344+ let ( buf, workspace) = workspace. split_at_mut ( * kv . get ( ) ) ;
345+ let mut v_ = kv . map ( |_| buf) ;
346+ self . mat_mul ( & mut v_, & v, attn_v, workspace, queue_alloc) ?;
344347
345- {
346348 let nh_dh = & [ d / dh, dh] ;
347349 let q = q_. map_slice_mut ( ) . tile ( 1 , nh_dh) . transpose ( & [ 1 , 0 ] ) ;
348350 let k = k_. tile ( 1 , nh_dh) . transpose ( & [ 1 , 0 ] ) ;
@@ -361,10 +363,15 @@ where
361363 }
362364 let o = q_;
363365
364- let ( buf, workspace) = workspace. split_at_mut ( * o_ . get ( ) ) ;
365- let mut o_ = o_ . map ( |_| buf) ;
366+ let ( buf, workspace) = workspace. split_at_mut ( * qo . get ( ) ) ;
367+ let mut o_ = qo . map ( |_| buf) ;
366368 self . mat_mul ( & mut o_, & o, attn_o, workspace, queue_alloc) ?;
367369
370+ let [ w, b] = weights. resampler_ln_post ( queue) ;
371+ let ln_post = [ ln. clone ( ) . map ( |_| w) , ln. clone ( ) . map ( |_| b) ] ;
372+ let inplace = unsafe { o_. map_slice_static ( ) } ;
373+ self . layer_norm ( & mut o_, & inplace, ln_post, workspace, queue_alloc) ?;
374+
368375 let mut out = o;
369376 let w = attn_w. map ( |_| weights. resampler_proj ( queue) ) ;
370377 self . mat_mul ( & mut out, & o_, ( w, None ) , workspace, queue_alloc) ?
0 commit comments