Commit 91249cc

Merge branch 'dev' into speedup-amax-kernel
2 parents 3c9de07 + 9eaaf4c commit 91249cc

6 files changed, +263 -189 lines changed

README.rst

Lines changed: 7 additions & 7 deletions
@@ -264,15 +264,15 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
 to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
 for both the `FusedAttention` and `DotProductAttention` modules.
 
-FA v3 Kernels in CK Backend
+AITER FA v3 Kernels
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
-To enable FA v3 kernels, the following environment variables can be used:
+ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend.
+This functionality can be controlled by the following environment variables:
 
-* NVTE_CK_USES_FWD_V3 - by default 0, if set to 1, some cases will call the fwd v3 kernel, only applicable to the gfx942 architecture;
-* NVTE_CK_USES_BWD_V3 - by default 0, if set to 1, some cases will call the bwd v3 dqdkdv kernel;
-* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when NVTE_CK_USES_BWD_V3 is set to 1;
-* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
+* NVTE_CK_USES_FWD_V3 - by default 1, if set to 0, v3 kernels will not be used for fwd pass;
+* NVTE_CK_USES_BWD_V3 - by default 1, if set to 0, v3 kernels will not be used for bwd pass;
+* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when v3 is enabled;
+* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when v3 is enabled, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
 
 Float to BFloat16 Conversion in CK Backend (gfx942 only)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
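
The four environment variables in the updated README text can be set from a Python launcher as well as from the shell. The following is a minimal sketch, not part of the commit. The variable names, defaults, and allowed values are taken from the new README lines above. The helper `v3_env` and the two tables are hypothetical names introduced for illustration. The README does not say when the library reads these variables, so the sketch sets them before any import of the library as the conservative choice.

```python
import os

# Defaults as stated in the updated README text (after this commit).
V3_DEFAULTS = {
    "NVTE_CK_USES_FWD_V3": "1",        # 0 disables v3 kernels for the fwd pass
    "NVTE_CK_USES_BWD_V3": "1",        # 0 disables v3 kernels for the bwd pass
    "NVTE_CK_IS_V3_ATOMIC_FP32": "1",  # 0 selects atomic fp16/bf16 in the bwd pass
    "NVTE_CK_HOW_V3_BF16_CVT": "1",    # 0: RTNE, 1: RTNA, 2: RTZ (gfx942 only)
}

# Allowed values per variable, also taken from the README text.
V3_ALLOWED = {
    "NVTE_CK_USES_FWD_V3": {"0", "1"},
    "NVTE_CK_USES_BWD_V3": {"0", "1"},
    "NVTE_CK_IS_V3_ATOMIC_FP32": {"0", "1"},
    "NVTE_CK_HOW_V3_BF16_CVT": {"0", "1", "2"},
}


def v3_env(overrides=None, base=None):
    """Hypothetical helper: return the effective v3 settings as a dict of strings.

    Values come from `base` (os.environ by default), fall back to the README
    defaults, and are then replaced by anything given in `overrides`.
    """
    base = os.environ if base is None else base
    settings = {name: base.get(name, default) for name, default in V3_DEFAULTS.items()}
    for name, value in (overrides or {}).items():
        if name not in V3_DEFAULTS:
            raise KeyError(f"unknown v3 variable: {name}")
        settings[name] = str(value)
    for name, value in settings.items():
        if value not in V3_ALLOWED[name]:
            raise ValueError(f"{name}={value!r} is not one of {sorted(V3_ALLOWED[name])}")
    return settings


# Example: keep v3 for the forward pass, disable it for the backward pass.
settings = v3_env({"NVTE_CK_USES_BWD_V3": 0}, base={})
os.environ.update(settings)  # do this before importing the library
```

Validating the values up front is a design choice of this sketch only. The README does not describe how the library treats values outside the listed ones.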

0 commit comments