Commit cd6f51b

Merge pull request #669 from ROCm/tianxing/FA-int8
Tianxing/fa int8
2 parents 9cdcf1d + 296209a

File tree

2 files changed: +343 -40 lines changed


python/perf-kernels/README.md

Lines changed: 18 additions & 0 deletions
@@ -42,9 +42,27 @@ This script contains the Flash Attention kernel with the following support
 - Multi and Grouped Query attention
 - ALiBi bias
 - Matrix bias
+- Int8 quantization

 These are currently supported for the forward kernel only.

+INT8 Quantization Support
+
+1. *q_descale*, *k_descale*, and *v_descale* provided:
+   - The first QK GEMM runs in INT8, then the output is dequantized to the specified *dtype*.
+   - The second PV GEMM runs in the specified *dtype*.
+
+2. *q_descale*, *k_descale*, *p_descale*, and *v_descale* provided:
+   - Both the first and second GEMM operations run in INT8.
+   - The results are dequantized to the specified *dtype* after both GEMMs.
+
+3. Only *k_descale* and *v_descale* provided:
+   - K and V are dequantized before the first and second GEMM operations, respectively.
+   - Both GEMMs run in the specified *dtype*.
+
+Note: The softmax operation is always performed in *fp32*.
+
+
 ## `06-attention-decode.py`

 This contains the Flash Decoding kernel.
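
For reference, a minimal sketch of how the per-tensor descale factors described in the README hunk above might be produced on the host before launching the kernel. The `quantize_int8` helper and the commented-out `attention(...)` call are illustrative assumptions, not the API added by this commit:

```python
import torch

def quantize_int8(x: torch.Tensor):
    """Per-tensor symmetric INT8 quantization (illustrative helper, not the kernel's API).

    Returns the int8 tensor plus a float32 descale factor; multiplying an
    int8 GEMM result by the relevant descale factors recovers the original scale.
    """
    scale = 127.0 / x.abs().max().clamp(min=1e-8)    # map the max magnitude to 127
    x_int8 = (x * scale).round().clamp(-128, 127).to(torch.int8)
    descale = scale.reciprocal().to(torch.float32)   # dequantization factor
    return x_int8, descale

# Hypothetical (batch, heads, seqlen, head_dim) inputs in the "specified dtype" (fp16 here).
q = torch.randn(2, 8, 1024, 64, dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, dtype=torch.float16)

q_int8, q_descale = quantize_int8(q)
k_int8, k_descale = quantize_int8(k)
v_int8, v_descale = quantize_int8(v)

# Case 1 above: passing q_descale, k_descale, and v_descale makes the QK GEMM run in
# INT8 and dequantizes its output to fp16 before the PV GEMM. The `attention` entry
# point below is a placeholder for the kernel's Python wrapper:
# out = attention(q_int8, k_int8, v_int8, causal=True,
#                 q_descale=q_descale, k_descale=k_descale, v_descale=v_descale)
```

Under case 3 above, only k and v would be quantized this way and passed with k_descale and v_descale, while q stays in the original dtype.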
