diff --git a/README.md b/README.md
index c163452..25e387f 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,7 @@ pip install . --no-build-isolation
 ```python
 import torch
 from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
 import math
 
 # Setup
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dty
 key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 
-# Create mask and bias for sparse attention
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# Create bias for sparse attention
+attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
-# Generate sparse mask based on bias
+# Generate dynamic mask based on bias
 if seq_len > window_size:
-    # Select top-k most important keys for each query
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=None,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
     )
-    # Generate valid top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
 
 # Select FDMA kernel
 flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
     query=query,
     key=key,
     value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0/math.sqrt(head_dim),
 )
@@ -208,13 +207,13 @@ print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
 query.requires_grad_(True)
 key.requires_grad_(True)
 value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)
 
 # Forward pass
 output = flash_dmattn_func(
     query=query, key=key, value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True, softmax_scale=1.0/math.sqrt(head_dim)
 )
 
@@ -226,7 +225,7 @@ loss.backward()
 print(f"Query gradient shape: {query.grad.shape}")
 print(f"Key gradient shape: {key.grad.shape}")
 print(f"Value gradient shape: {value.grad.shape}")
-print(f"Bias gradient shape: {attention_bias.grad.shape}")
+print(f"Bias gradient shape: {attn_bias.grad.shape}")
 ```
 
diff --git a/README_zh.md b/README_zh.md
index 2550652..1990b58 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -153,6 +153,7 @@ pip install . --no-build-isolation
 ```python
 import torch
 from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
 import math
 
 # 设置
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dty
 key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 
-# 为稀疏注意力创建 mask 和 bias
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# 为稀疏注意力创建 bias
+attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
-# 基于 bias 生成稀疏 mask
+# 基于 bias 生成动态 mask
 if seq_len > window_size:
-    # 为每个查询选择 top-k 最重要的键
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=None,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
     )
-    # 生成有效的 top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
 
 # 选择 FDMA 内核
 flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
     query=query,
     key=key,
     value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0/math.sqrt(head_dim),
 )
@@ -208,13 +207,13 @@ print(f"输出形状: {output.shape}") # [1, 256, 2, 64]
 query.requires_grad_(True)
 key.requires_grad_(True)
 value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)
 
 # 前向传播
 output = flash_dmattn_func(
     query=query, key=key, value=value,
-    attn_mask=attention_mask, attn_bias=attention_bias,
+    attn_mask=attn_mask, attn_bias=attn_bias,
     is_causal=True, softmax_scale=1.0/math.sqrt(head_dim)
 )
 
@@ -226,7 +225,7 @@ loss.backward()
 print(f"Query 梯度形状: {query.grad.shape}")
 print(f"Key 梯度形状: {key.grad.shape}")
 print(f"Value 梯度形状: {value.grad.shape}")
-print(f"Bias 梯度形状: {attention_bias.grad.shape}")
+print(f"Bias 梯度形状: {attn_bias.grad.shape}")
 ```
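For convenience, here is the patched quick-start assembled into one runnable sketch. The `create_mask` keyword arguments and the `flash_dmattn_func_auto` call are taken directly from the diff above; the concrete setup values (`num_kv_heads`, `window_size`, `device`, `dtype`), the `min_dtype = torch.finfo(dtype).min` definition, and the `attn_mask = None` fallback for short sequences are assumptions standing in for the setup block that the hunks do not show.

```python
import math
import torch

from flash_dmattn import flash_dmattn_func_auto
from flash_dmattn.utils.mask import create_mask

# Assumed setup values; the README's own setup block (not shown in the hunks) defines these.
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
window_size = 128                       # assumed number of keys each query may attend to
device, dtype = torch.device("cuda"), torch.bfloat16
min_dtype = torch.finfo(dtype).min      # assumed fill value for masked-out positions

query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

# Dynamic mask derived from the bias, as in the patched example; falling back to
# attn_mask=None for short sequences is an assumption, not part of the patch.
attn_mask = None
if seq_len > window_size:
    attn_mask = create_mask(
        attention_bias=attn_bias,
        attention_mask=None,
        batch_size=batch_size,
        query_len=seq_len,
        key_len=seq_len,
        window_size=window_size,
        min_dtype=min_dtype,
    )

# Select the CUDA kernel and run the forward pass.
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
output = flash_dmattn_func(
    query=query, key=key, value=value,
    attn_mask=attn_mask, attn_bias=attn_bias,
    is_causal=True, softmax_scale=1.0 / math.sqrt(head_dim),
)
print(output.shape)  # expected [1, 256, 2, 64], matching the README's print
```

If the installed version of `flash_dmattn_func` does not accept `attn_mask=None`, pass a dense all-ones mask of shape `(batch_size, num_kv_heads, seq_len, seq_len)` for the short-sequence case instead.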