@@ -361,7 +361,7 @@ def forward_inference(
361361
362362 # rotate after updating the compression running k/v
363363
364- q = self .rotary_emb .rotate_queries_or_keys (q , offset = cache_len )
364+ rotated_q = self .rotary_emb .rotate_queries_or_keys (q , offset = cache_len )
365365 k = self .rotary_emb .rotate_queries_or_keys (k , offset = cache_len )
366366
367367 # handle cache, which stores the rotated
@@ -459,7 +459,7 @@ def forward_inference(
459459
460460 # remove later
461461
462- fq = rearrange (q , 'b (h gh) ... -> b h gh ...' , gh = self .num_grouped_queries )
462+ fq = rearrange (rotated_q , 'b (h gh) ... -> b h gh ...' , gh = self .num_grouped_queries )
463463
464464 fsim = einsum (fq , fk , 'b h gh i d, b h j d -> b h gh i j' ) * scale
465465
@@ -476,11 +476,12 @@ def forward_inference(
476476 v = repeat (v , 'b h ... -> b (h gh) ...' , gh = self .num_grouped_queries )
477477
478478 sliding_slice = (Ellipsis , slice (- (sliding_window + 1 ), None ), slice (None ))
479- rotated_q , rotated_k = self .rotary_emb .rotate_queries_with_cached_keys (q , k [sliding_slice ])
480479
481- sim = einsum (rotated_q , rotated_k , 'b h i d, b h j d -> b h i j' ) * scale
480+ k , v = k [sliding_slice ], v [sliding_slice ]
481+
482+ sim = einsum (rotated_q , k , 'b h i d, b h j d -> b h i j' ) * scale
482483 attn = sim .softmax (dim = - 1 )
483- sliding_window_attn_out = einsum (attn , v [ sliding_slice ] , 'b h i j, b h j d -> b h i d' )
484+ sliding_window_attn_out = einsum (attn , v , 'b h i j, b h j d -> b h i d' )
484485
485486 # combine strategies
486487
@@ -630,8 +631,8 @@ def forward(
630631
631632 # handle if number of total blocks is less than number to select for fine attention
632633
633- fq = rotated_q
634- fk = rotated_k
634+ fq = q
635+ fk = k
635636 fv = v
636637
637638 if has_selected_kv_for_fine_attn :
@@ -757,8 +758,8 @@ def forward(
757758
758759 # 3. overlapping sliding window, this is unsurprising and expected - `s` for sliding
759760
760- sq = rotated_q
761- sk = rotated_k
761+ sq = q
762+ sk = k
762763 sv = v
763764
764765 if exists (sliding_window_flex_mask ):
0 commit comments