@@ -218,6 +218,7 @@ def forward(
         fk = k
         fv = v

+
         if seq_len < fine_divisible_seq_len:
             remainder = fine_divisible_seq_len - seq_len
             fk = pad_at_dim(fk, (0, remainder), value = 0., dim = -2)
@@ -228,13 +229,30 @@ def forward(

             selected_block_indices = pad_at_dim(selected_block_indices, (0, remainder), value = 0, dim = -2)

+        # handle block causal diagonal in the diagram, but run experiments without to see
+
+        fine_window_seq = arange(fine_divisible_seq_len, device = device) // self.selection_block_size
+        fine_window_seq = rearrange(fine_window_seq, 'n -> n 1').expand_as(selected_block_indices)
+        selected_block_indices = cat((selected_block_indices, fine_window_seq), dim = -1) # for the block causal diagonal in fig2
+
+        fmask = repeat(fmask, 'b h i w -> b h i w j', j = self.selection_block_size)
+
+        causal_mask = torch.ones((self.selection_block_size,) * 2, device = device, dtype = torch.bool).tril()
+        causal_mask = repeat(causal_mask, 'i j -> (w i) 1 j', w = num_fine_blocks).expand_as(fmask)
+
+        fmask = cat((fmask, causal_mask), dim = -2)
+        fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')
+
+        # select out the spatial crops of keys / values for fine attention
+
         fk = rearrange(fk, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
         fv = rearrange(fv, 'b h (w n) d -> b h w n d', w = num_fine_blocks)
-        fmask = repeat(fmask, 'b h i w -> b h i (w j)', j = self.selection_block_size)

         fk = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fk, selected_block_indices)
         fv = einx.get_at('b h [w] j d, b h i selected -> b h i (selected j) d', fv, selected_block_indices)

+        # fine attention
+
         fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * self.scale

         fsim = fsim.masked_fill(~fmask, mask_value)
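
A few sketches to make the diff above easier to follow. First, `pad_at_dim`: its source is not part of this diff, but the usual form of this helper (pad an arbitrary dimension via `F.pad`) would be, as an assumption:

```python
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # F.pad takes (left, right) pairs starting from the last dimension,
    # so translate `dim` into the number of leading zero pairs needed
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)
```

With this, `pad_at_dim(fk, (0, remainder), value = 0., dim = -2)` right-pads the key sequence dimension up to `fine_divisible_seq_len`.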
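The heart of the change is appending each query's own block to its selected blocks, so the fine pass always covers the block-causal diagonal from the paper's figure 2, with a within-block causal mask for that appended block. A minimal standalone sketch with toy sizes (names like `sel` and `own_block` are illustrative, not from the repo; for clarity it appends a single diagonal block, where the commit tiles it across the selected dimension with `expand_as`):

```python
import torch
from einops import rearrange, repeat

b, h, seq = 1, 2, 8          # toy batch, heads, block-divisible sequence length
block = 4                    # selection_block_size
num_blocks = seq // block    # num_fine_blocks
num_sel = 1                  # top-k selected blocks per query

# hypothetical top-k block indices per query position, (b, h, seq, num_sel)
sel = torch.randint(0, num_blocks, (b, h, seq, num_sel))

# append each query's own block index as an extra "selected" block
own_block = torch.arange(seq) // block
own_block = repeat(own_block, 'i -> b h i 1', b = b, h = h)
sel = torch.cat((sel, own_block), dim = -1)            # (b, h, seq, num_sel + 1)

# mask: the selected blocks are fully visible per token ...
fmask = torch.ones((b, h, seq, num_sel), dtype = torch.bool)
fmask = repeat(fmask, 'b h i w -> b h i w j', j = block)

# ... while the appended diagonal block is causally masked within the block
causal = torch.ones((block, block), dtype = torch.bool).tril()
causal = repeat(causal, 'i j -> b h (n i) 1 j', b = b, h = h, n = num_blocks)

fmask = torch.cat((fmask, causal), dim = -2)           # (b, h, seq, num_sel + 1, block)
fmask = rearrange(fmask, 'b h i w j -> b h i (w j)')   # flatten to per-key mask
```

Here `fmask` is True where attention is allowed: selected blocks expose all of their keys, while the diagonal block only exposes keys at or before the query's own position within the block.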
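The `einx.get_at` lines then gather, for every query position, the `selection_block_size` keys/values of each of its selected blocks. An equivalent (if memory-hungrier) spelling with plain `torch.gather`, under the same assumed shapes:

```python
import torch
from einops import rearrange, repeat

def gather_selected(fk, sel):
    # fk:  (b, h, w, n, d)  keys grouped into w blocks of n tokens each
    # sel: (b, h, i, s)     s selected block indices per query position
    b, h, w, n, d = fk.shape
    i = sel.shape[2]
    fk = repeat(fk, 'b h w n d -> b h i w n d', i = i)
    idx = repeat(sel, 'b h i s -> b h i s n d', n = n, d = d)
    out = fk.gather(3, idx)                        # (b, h, i, s, n, d)
    return rearrange(out, 'b h i s n d -> b h i (s n) d')
```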
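Finally, the fine attention itself is ordinary masked attention over the gathered keys and values. A sketch under the convention assumed above, that `fmask` is True where attention is permitted (hence masking on `~fmask`):

```python
import torch
from einops import einsum

def fine_attend(fq, fk, fv, fmask, scale):
    # fq: (b, h, i, d); fk, fv: (b, h, i, j, d); fmask: (b, h, i, j), True = attend
    fsim = einsum(fq, fk, 'b h i d, b h i j d -> b h i j') * scale
    fsim = fsim.masked_fill(~fmask, -torch.finfo(fsim.dtype).max)
    fattn = fsim.softmax(dim = -1)
    return einsum(fattn, fv, 'b h i j, b h i j d -> b h i d')
```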