@@ -153,6 +153,7 @@ pip install . --no-build-isolation
 ```python
 import torch
 from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
 import math
 
 # Setup
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype
 key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 
-# Create mask and bias for sparse attention
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# Create bias for sparse attention
+attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
-# Generate sparse mask based on bias
+# Generate dynamic mask based on bias
 if seq_len > window_size:
-    # Select top-k most important keys for each query
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=attn_mask,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
     )
-    # Generate valid top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
 
 # Select FDMA kernel
 flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
     query=query,
     key=key,
     value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0 / math.sqrt(head_dim),
 )
@@ -208,13 +207,13 @@ print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
 query.requires_grad_(True)
 key.requires_grad_(True)
 value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)
 
 # Forward pass
 output = flash_dmattn_func(
     query=query, key=key, value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0 / math.sqrt(head_dim)
 )
@@ -226,7 +225,7 @@ loss.backward()
 print(f"Query gradient shape: {query.grad.shape}")
 print(f"Key gradient shape: {key.grad.shape}")
 print(f"Value gradient shape: {value.grad.shape}")
-print(f"Bias gradient shape: {attention_bias.grad.shape}")
+print(f"Bias gradient shape: {attn_bias.grad.shape}")
 ```
 
 
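For context on what this change encapsulates: the removed lines built the sparse mask inline via a top-k selection over the bias, and the new `create_mask` helper takes over that role. The sketch below is a minimal, self-contained reconstruction of the removed logic, assuming `create_mask` performs a similar selection under the hood; the helper name `build_topk_mask` and the toy shapes are illustrative and not part of the flash_dmattn API.

```python
import torch

def build_topk_mask(attn_bias, window_size, min_dtype):
    # Illustrative reconstruction of the removed README block:
    # keep the top-k highest-bias keys per query, zero out the rest.
    topk_values, topk_indices = torch.topk(
        attn_bias, window_size, dim=-1, largest=True, sorted=False
    )
    # Positions whose bias already equals min_dtype are treated as invalid.
    valid_topk = (topk_values != min_dtype).to(attn_bias.dtype)
    attn_mask = torch.zeros_like(attn_bias).scatter(-1, topk_indices, valid_topk)
    # The removed code also pushed masked-out bias entries down to min_dtype.
    attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
    return attn_mask, attn_bias

# Toy shapes mirroring the example: [batch, num_kv_heads, seq_len, seq_len]
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
min_dtype = torch.finfo(dtype).min
attn_bias = torch.randn(1, 1, 256, 256, device=device, dtype=dtype)
attn_mask, attn_bias = build_topk_mask(attn_bias, window_size=128, min_dtype=min_dtype)
print(attn_mask.sum(-1).max())  # each query row keeps at most window_size keys
```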