@@ -67,17 +67,24 @@ def compress_mask(_, __, q_idx, kv_idx):
6767
6868
def create_fine_mask(selected_block_indices: Tensor, seq_len: int, fine_block_size: int):
    """Build a FlexAttention block mask for the fine (selected-block) attention branch.

    A query position may attend to a key position when either

    * both fall in the same fine block and the key is not in the future
      (causal attention inside the block diagonal), or
    * the key lies in a strictly earlier fine block that was selected for
      that query.

    Args:
        selected_block_indices: fine block ids chosen per query. It is indexed
            below as ``[b_idx, h_idx, q_idx, block]`` after a scatter along the
            last dim, so it is presumably shaped
            ``(batch, heads, seq_len, num_selected)`` with an integer dtype
            accepted by ``Tensor.scatter_`` (int64) - TODO confirm with caller.
        seq_len: used as both ``Q_LEN`` and ``KV_LEN`` of the block mask.
        fine_block_size: number of positions per fine block.

    Returns:
        Whatever ``create_block_mask`` returns for the mask function built here.
    """
    device = selected_block_indices.device
    batch, heads = selected_block_indices.shape[:2]

    # Ceil division: with a plain floor, a trailing partial block would get id
    # `seq_len // fine_block_size`, one past the end of the lookup table below.
    # Identical to the floor when seq_len is a multiple of fine_block_size.
    num_fine_blocks = (seq_len + fine_block_size - 1) // fine_block_size

    # Turn the list of selected block ids into a boolean membership table over
    # all fine blocks, so the mask function is a single gather.
    one_hot_selected_block_indices = torch.zeros((*selected_block_indices.shape[:-1], num_fine_blocks), device = device, dtype = torch.bool)
    one_hot_selected_block_indices.scatter_(-1, selected_block_indices, True)

    def fine_mask(b_idx, h_idx, q_idx, kv_idx):
        # Fine block ids of the query and key positions.
        compressed_q_idx = q_idx // fine_block_size
        compressed_kv_idx = kv_idx // fine_block_size

        # Key block lies strictly before the query's block.
        block_causal_mask = compressed_q_idx > compressed_kv_idx
        is_selected = one_hot_selected_block_indices[b_idx, h_idx, q_idx, compressed_kv_idx]

        # Token-level causality, only needed within the query's own block.
        causal_mask = q_idx >= kv_idx
        block_diagonal = compressed_q_idx == compressed_kv_idx

        return (causal_mask & block_diagonal) | (block_causal_mask & is_selected)

    block_mask = create_block_mask(fine_mask, B = batch, H = heads, Q_LEN = seq_len, KV_LEN = seq_len, _compile = True)
    return block_mask
0 commit comments