Skip to content

Commit 464baf7

Browse files
authored
Merge pull request #190 from SmallDoges/fix-189
[FEATURE SUPPORT] Broadcastable 4D mask/bias, 128‑rounded key length, stride‑0 broadcasting, and dbias reductions
2 parents b500c36 + 08392c8 commit 464baf7

File tree

8 files changed

+847
-1090
lines changed

8 files changed

+847
-1090
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A
1818
## Key Features
1919

2020
### 🎯 Core Kernel Advantages
21-
- **Mask & Bias Support**: Native support for `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped attention mask and attention bias tensors
21+
- **Mask & Bias Support**: Native support for `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped attention mask and attention bias tensors
2222
- **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks
2323
- **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training
2424

@@ -236,9 +236,9 @@ Flash-DMA integrates the efficient memory access patterns of Flash Attention wit
236236

237237
### Core Technology Integration
238238

239-
- **🎯 Native Mask & Bias Support**: Kernels directly process `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` shaped tensors
239+
- **🎯 Native Mask & Bias Support**: Kernels directly process `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` shaped tensors
240240
- **⚡ Block-level Intelligent Skipping**: Unified OR-reduction skipping logic based on masks, completely avoiding computation and memory access for zero blocks
241-
- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation (dbias) supporting end-to-end differentiable training
241+
- **🔄 Complete Gradient Chain**: Built-in attention bias gradient computation supporting end-to-end differentiable training
242242

243243
### Key Optimization Strategies
244244

README_zh.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Flash-DMA 是一个高性能的注意力实现,将 Flash Attention 的内存
1818
## 主要特性
1919

2020
### 🎯 核心内核优势
21-
- **Mask & Bias 支持**: 原生支持 `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` 形状的 attention_mask 和 attention_bias 张量
21+
- **Mask & Bias 支持**: 原生支持 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的 attention_mask 和 attention_bias 张量
2222
- **智能计算跳过**: 基于 attention_mask 的 block-level 自动跳过机制,完全跳过全零 mask 区块的计算和内存访问
2323
- **完整梯度支持**: 内置 attention_bias 的完整梯度计算路径,支持端到端训练
2424

@@ -236,7 +236,7 @@ Flash-DMA 通过将 Flash Attention 的高效内存访问模式与动态掩码
236236

237237
### 核心技术融合
238238

239-
- **🎯 Mask & Bias 原生支持**: 内核直接处理 `(batch_size, {1|num_kv_heads|num_heads}, {0|query_len}, key_len)` 形状的张量
239+
- **🎯 Mask & Bias 原生支持**: 内核直接处理 `({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len})` 形状的张量
240240
- **⚡ Block-level 智能跳过**: 基于 mask 的统一 OR-reduction 跳过逻辑,完全避免全零区块的计算和内存访问
241241
- **🔄 完整梯度链路**: 内置 attention bias 梯度计算,支持端到端可微分训练
242242

0 commit comments

Comments
 (0)