You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.rst
+7-7Lines changed: 7 additions & 7 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -264,15 +264,15 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
264
264
to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
265
265
for both the `FusedAttention` and `DotProductAttention` modules.
266
266
267
-
FA v3 Kernels in CK Backend
267
+
AITER FA v3 Kernels
268
268
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
269
-
ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
270
-
To enable FA v3 kernels, the following environment variables can be used:
269
+
ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend.
270
+
This functionality can be controlled by the following environment variables:
271
271
272
-
* NVTE_CK_USES_FWD_V3 - by default 0, if set to 1, some cases will call the fwd v3 kernel, only applicable to the gfx942 architecture;
273
-
* NVTE_CK_USES_BWD_V3 - by default 0, if set to 1, some cases will call the bwd v3 dqdkdv kernel;
274
-
* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when NVTE_CK_USES_BWD_V3 is set to 1;
275
-
* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
272
+
* NVTE_CK_USES_FWD_V3 - by default 1, if set to 0, v3 kernels will not be used for fwd pass;
273
+
* NVTE_CK_USES_BWD_V3 - by default 1, if set to 0, v3 kernels will not be used for bwd pass;
274
+
* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when v3 is enabled;
275
+
* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when v3 is enabled, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
276
276
277
277
Float to BFloat16 Conversion in CK Backend (gfx942 only)
0 commit comments