@@ -68,6 +68,7 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
+        causal = True,
         kv_heads = None
     ):
         super().__init__()
@@ -78,6 +79,8 @@ def __init__(
         dim_inner = heads * dim_head
         dim_kv_inner = kv_heads * dim_head
 
+        self.causal = causal
+
         self.rotary_embed = RotaryEmbedding(dim_head)
 
         self.to_q = nn.Linear(dim, dim_inner, bias = False)
@@ -114,7 +117,7 @@ def forward(
 
         out = F.scaled_dot_product_attention(
             q, k, v,
-            is_causal = True
+            is_causal = self.causal
         )
 
         out = self.merge_heads(out)
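
The changed line above toggles PyTorch's built-in causal masking: `F.scaled_dot_product_attention` applies a lower-triangular mask when `is_causal = True` and attends over the full sequence when `is_causal = False`. A minimal sketch of the two behaviors (shapes are illustrative only, not taken from this repo):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, dim_head) -- illustrative shapes
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# is_causal = True: query position i can only attend to key positions <= i
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# is_causal = False: full bidirectional attention, e.g. for encoder-style models
bidirectional_out = F.scaled_dot_product_attention(q, k, v, is_causal = False)
```
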
@@ -146,6 +149,7 @@ def __init__(
         kv_heads = None,
         ff_expansion_factor = 4.,
         use_sparse_attn = False,
+        causal = True,
         use_flex_sliding_window = False,
         use_flex_fine_selection = False,
         use_triton_fine_selection = False,
@@ -164,6 +168,8 @@ def __init__(
         if use_flex_sliding_window or use_flex_fine_selection:
             assert exists(flex_attention), 'flex attention is not available on your current version of pytorch'
 
+        self.causal = causal
+
         self.use_sparse_attn = use_sparse_attn
         self.use_flex_sliding_window = use_sparse_attn & use_flex_sliding_window
         self.use_flex_fine_selection = use_sparse_attn & use_flex_fine_selection
@@ -177,6 +183,7 @@ def __init__(
                     dim_head = dim_head,
                     heads = heads,
                     kv_heads = kv_heads,
+                    causal = causal,
                     use_triton_kernel = use_triton_fine_selection,
                     **sparse_attn_kwargs
                 )
@@ -185,6 +192,7 @@ def __init__(
                     dim = dim,
                     dim_head = dim_head,
                     heads = heads,
+                    causal = causal,
                     kv_heads = kv_heads
                 )
 
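
Taken together, the constructor changes simply thread one flag from the top-level module down into each attention layer. A self-contained sketch of the pattern (`ToyAttention` is a stand-in written for illustration, not a class from this repo):

```python
import torch
from torch import nn
import torch.nn.functional as F

class ToyAttention(nn.Module):
    # mirrors the diff: `causal` is stored at init time
    # and forwarded to scaled_dot_product_attention on every call
    def __init__(self, dim, heads = 8, dim_head = 64, causal = True):
        super().__init__()
        self.causal = causal
        self.heads = heads
        self.to_qkv = nn.Linear(dim, heads * dim_head * 3, bias = False)
        self.to_out = nn.Linear(heads * dim_head, dim, bias = False)

    def forward(self, x):
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2) for t in qkv)
        out = F.scaled_dot_product_attention(q, k, v, is_causal = self.causal)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

attn = ToyAttention(dim = 512, causal = False)  # bidirectional variant enabled by this commit
out = attn(torch.randn(1, 16, 512))             # (1, 16, 512)
```
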
@@ -275,12 +283,12 @@ def forward(
 
         if not disable_flex and self.use_flex_sliding_window:
             attn_kwargs.update(
-                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size)
+                sliding_window_flex_mask = create_sliding_mask(seq_len, self.attn_sliding_window_size, causal = self.causal)
             )
 
         if not disable_flex and self.use_flex_fine_selection:
             attn_kwargs.update(
-                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size)
+                fine_selection_flex_mask = create_fine_mask(seq_len, self.attn_fine_block_size, causal = self.causal)
             )
 
         # cache
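
The last two changes pass the flag into the flex-attention mask builders. Assuming those helpers are built on `torch.nn.attention.flex_attention`-style mask predicates (an inference from the names, not confirmed against this repo), the `causal` switch plausibly amounts to choosing between a one-sided and a symmetric window predicate, along these lines:

```python
# sketch of mask_mod predicates of the kind passed to flex attention's
# create_block_mask; the helpers in the diff (create_sliding_mask,
# create_fine_mask) belong to this repo and may differ in detail
SLIDING_WINDOW = 64

def sliding_window_causal(b, h, q_idx, kv_idx):
    # causal: query i sees keys in (i - window, i]
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW)

def sliding_window_bidirectional(b, h, q_idx, kv_idx):
    # non-causal: query i sees keys within the window on both sides
    return (q_idx - kv_idx).abs() <= SLIDING_WINDOW
```
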