@@ -149,6 +149,7 @@ where
149149 {
150150 let time = Instant :: now ( ) ;
151151 let Args {
152+ img_embd : proj_q,
152153 raw,
153154 pos,
154155 pos_resampler,
@@ -317,10 +318,7 @@ where
317318 let [ w, b] = weights. resampler_attn_o ( queue) ;
318319 let attn_o = ( attn_w. clone ( ) . map ( |_| w) , Some ( attn_b. clone ( ) . map ( |_| b) ) ) ;
319320
320- let qo = Tensor :: new ( dt, & [ batch * dq, d] ) ;
321-
322- let ( buf, workspace) = workspace. split_at_mut ( * qo. get ( ) ) ;
323- let mut q_ = qo. clone ( ) . map ( |_| buf) ;
321+ let mut q_ = proj_q;
324322 {
325323 let mut q_ = q_. map_slice_mut ( ) . tile ( 0 , & [ batch, dq] ) ;
326324 {
@@ -363,18 +361,19 @@ where
363361 }
364362 let o = q_;
365363
366- let ( buf, workspace) = workspace. split_at_mut ( * qo. get ( ) ) ;
367- let mut o_ = qo. map ( |_| buf) ;
364+ let o_ = Tensor :: new ( o. dt ( ) , o. shape ( ) ) ;
365+ let ( buf, workspace) = workspace. split_at_mut ( * o_. get ( ) ) ;
366+ let mut o_ = o_. map ( |_| buf) ;
368367 self . mat_mul ( & mut o_, & o, attn_o, workspace, queue_alloc) ?;
369368
370369 let [ w, b] = weights. resampler_ln_post ( queue) ;
371370 let ln_post = [ ln. clone ( ) . map ( |_| w) , ln. clone ( ) . map ( |_| b) ] ;
372371 let inplace = unsafe { o_. map_slice_static ( ) } ;
373372 self . layer_norm ( & mut o_, & inplace, ln_post, workspace, queue_alloc) ?;
374373
375- let mut out = o;
374+ let mut img_embd = o;
376375 let w = attn_w. map ( |_| weights. resampler_proj ( queue) ) ;
377- self . mat_mul ( & mut out , & o_, ( w, None ) , workspace, queue_alloc) ?
376+ self . mat_mul ( & mut img_embd , & o_, ( w, None ) , workspace, queue_alloc) ?
378377 }
379378 }
380379
0 commit comments