@@ -318,7 +318,7 @@ def forward(

# for gqa, we will average the compressed attention across each grouped queries (per key / values)

- importance_scores = reduce(importance_scores, 'b (grouped_queries h) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)
+ importance_scores = reduce(importance_scores, 'b (h grouped_queries) ... -> b h ...', 'mean', grouped_queries = self.num_grouped_queries)

# handle if compress block size does not equal to the fine block size
# cannot parse their equation, so will just improvise
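Note on the hunk above: the fix assumes query heads are laid out kv-head-major, i.e. all grouped query heads belonging to kv head 0 come first, then those for kv head 1, and so on. With that layout, '(h grouped_queries)' averages exactly the query heads that share a key/value head, while the previous '(grouped_queries h)' ordering mixes heads from different kv groups. A minimal sketch with toy sizes (illustrative only, not the repo's defaults):

import torch
from einops import reduce

kv_heads, grouped_queries = 2, 3  # toy sizes, for illustration only

# importance scores per query head, kv-head-major: [kv0_q0, kv0_q1, kv0_q2, kv1_q0, kv1_q1, kv1_q2]
scores = torch.arange(kv_heads * grouped_queries, dtype = torch.float)[None, :, None]  # (b, heads, n)

# new ordering - each kv head averages its own grouped query heads
per_kv = reduce(scores, 'b (h grouped_queries) n -> b h n', 'mean', grouped_queries = grouped_queries)
print(per_kv.flatten().tolist())  # [1.0, 4.0] - means of [0, 1, 2] and [3, 4, 5]

# previous ordering - averages query heads belonging to different kv groups
mixed = reduce(scores, 'b (grouped_queries h) n -> b h n', 'mean', grouped_queries = grouped_queries)
print(mixed.flatten().tolist())  # [2.0, 3.0] - means of [0, 2, 4] and [1, 3, 5]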
@@ -349,7 +349,7 @@ def forward(
if exists(fine_selection_flex_mask):
    # flex attention for the selection for fine attention

-   fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))
+   fk, fv, selected_block_indices = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, selected_block_indices))

    fine_block_mask = fine_selection_flex_mask(selected_block_indices)

@@ -413,7 +413,7 @@ def forward(

# fine attention

- fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))
+ fk, fv, fmask = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv, fmask))

fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale

@@ -430,7 +430,7 @@ def forward(
seq_len = fk.shape[-2]
fmask = causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).tril()

- fk, fv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))
+ fk, fv = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (fk, fv))

fsim = einsum(fq, fk, 'b h i d, b h j d -> b h i j') * self.scale

@@ -449,7 +449,7 @@ def forward(
if exists(sliding_window_flex_mask):
    sliding_window_attn_out = flex_attention(sq, sk, sv, block_mask = sliding_window_flex_mask, enable_gqa = True)
else:
-   sk, sv = tuple(repeat(t, 'b h ... -> b (num_grouped_queries h) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))
+   sk, sv = tuple(repeat(t, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = self.num_grouped_queries) for t in (sk, sv))

    sliding_window_attn_out = self.sliding_window(sq, sk, sv)

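The remaining hunks apply the same convention to the key/value side: kv heads (and the associated masks / block indices) are repeated contiguously so that each block of grouped query heads lines up with its own key/value head. A small sketch under the same assumed kv-head-major layout (toy sizes, illustrative only):

import torch
from einops import repeat

kv_heads, grouped_queries, seq_len, dim = 2, 3, 4, 8  # toy sizes, for illustration only

k = torch.randn(1, kv_heads, seq_len, dim)

# '(h num_grouped_queries)' repeats each kv head contiguously: [k0, k0, k0, k1, k1, k1]
k_expanded = repeat(k, 'b h ... -> b (h num_grouped_queries) ...', num_grouped_queries = grouped_queries)

# with query heads laid out kv-head-major, query head i reads kv head i // grouped_queries
for i in range(kv_heads * grouped_queries):
    assert torch.equal(k_expanded[:, i], k[:, i // grouped_queries])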