
Commit 9eaaf4c

Enable AITER V3 kernels by default (#372)
Parent: c95f9db

File tree
6 files changed: +47 -22 lines

README.rst

Lines changed: 7 additions & 7 deletions
@@ -264,15 +264,15 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
 to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
 for both the `FusedAttention` and `DotProductAttention` modules.
 
-FA v3 Kernels in CK Backend
+AITER FA v3 Kernels
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
-To enable FA v3 kernels, the following environment variables can be used:
+ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using the AITER backend.
+This functionality can be controlled by the following environment variables:
 
-* NVTE_CK_USES_FWD_V3 - by default 0, if set to 1, some cases will call the fwd v3 kernel, only applicable to the gfx942 architecture;
-* NVTE_CK_USES_BWD_V3 - by default 0, if set to 1, some cases will call the bwd v3 dqdkdv kernel;
-* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when NVTE_CK_USES_BWD_V3 is set to 1;
-* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
+* NVTE_CK_USES_FWD_V3 - by default 1; if set to 0, v3 kernels will not be used for the fwd pass;
+* NVTE_CK_USES_BWD_V3 - by default 1; if set to 0, v3 kernels will not be used for the bwd pass;
+* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1; if set to 0, uses atomic fp16/bf16 (without the convert_dq kernel) in the bwd pass when v3 is enabled;
+* NVTE_CK_HOW_V3_BF16_CVT - by default 1; the float-to-bf16 conversion type when v3 is enabled (0: RTNE, 1: RTNA, 2: RTZ); only applicable to the gfx942 architecture.
 
 Float to BFloat16 Conversion in CK Backend (gfx942 only)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
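
The environment variables above are read by the CK fused-attention implementation when the kernels are dispatched (see the getenv calls in fused_attn_ck.cpp later in this commit), so they can be set in the shell or from Python before the attention call runs. A minimal sketch, assuming a PyTorch workload, of opting back into the v2 path now that v3 is the default:

    import os

    # With this commit the AITER/CK v3 kernels default to "1" (enabled);
    # setting these to "0" falls back to the v2 fwd/bwd kernels.
    os.environ["NVTE_CK_USES_FWD_V3"] = "0"
    os.environ["NVTE_CK_USES_BWD_V3"] = "0"

    # Optional bwd-pass knobs, shown at their defaults:
    os.environ["NVTE_CK_IS_V3_ATOMIC_FP32"] = "1"  # 0 -> atomic fp16/bf16 without the convert_dq kernel
    os.environ["NVTE_CK_HOW_V3_BF16_CVT"] = "1"    # 0: RTNE, 1: RTNA, 2: RTZ (gfx942 only)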

benchmarks/attention/benchmark_attention_rocm.py

Lines changed: 6 additions & 7 deletions
@@ -104,19 +104,18 @@ def cleanup_env():
     for var in ATTENTION_ENV_VARS:
         os.environ[var] = "0"
 
-def setup_backend_env(backend_name, use_ck_bwd_v3=False, use_ck_fwd_v3=False, use_ck_v3_a16=False):
+def setup_backend_env(backend_name, use_ck_bwd_v3=True, use_ck_fwd_v3=True, use_ck_v3_a16=False):
     cleanup_env()
 
     if backend_name == "flash":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     elif backend_name == "fused_ck":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+        os.environ["NVTE_CK_USES_BWD_V3"] = "1" if use_ck_bwd_v3 else "0"
         if use_ck_bwd_v3:
-            os.environ["NVTE_CK_USES_BWD_V3"] = "1"
             os.environ["NVTE_CK_IS_V3_ATOMIC_FP32"] = "0" if use_ck_v3_a16 else "1"
-        if use_ck_fwd_v3:
-            os.environ["NVTE_CK_USES_FWD_V3"] = "1"
+        os.environ["NVTE_CK_USES_FWD_V3"] = "1" if use_ck_fwd_v3 else "0"
     elif backend_name == "fused_aotriton":
         os.environ["NVTE_FUSED_ATTN"] = "1"
         os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
@@ -359,7 +358,7 @@ def main(args):
         print(
             f"Device {device_id}: "
             f"{device_properties.name} GPU, "
-            f"sm{device_properties.major}{device_properties.minor} compute capability, "
+            f"{device_properties.gcnArchName.split(':')[0]} architecture, "
             f"{device_properties.total_memory/1024**3:.1f}GB memory"
         )
     # Benchmarking starts..
@@ -438,8 +437,8 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--use_ck_bwd_v3", action="store_true", help="Use NVTE_CK_USES_BWD_V3=1 for CK bwd kernels")
-    parser.add_argument("--use_ck_fwd_v3", action="store_true", help="Use NVTE_CK_USES_FWD_V3=1 for CK fwd kernels")
+    parser.add_argument("--no_ck_bwd_v3", action="store_false", dest="use_ck_bwd_v3", help="Set NVTE_CK_USES_BWD_V3=0 for CK bwd kernels")
+    parser.add_argument("--no_ck_fwd_v3", action="store_false", dest="use_ck_fwd_v3", help="Set NVTE_CK_USES_FWD_V3=0 for CK fwd kernels")
     parser.add_argument("--use_ck_v3_a16", action="store_true", help="Use NVTE_CK_IS_V3_ATOMIC_FP32=0 for atomic16. Default is 1")
     parser.add_argument("--run_sanity_checks", action="store_true", help="After benchmarking, verify profiler outputs.")
     args = parser.parse_args()
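
The switch from --use_ck_*_v3 with store_true to --no_ck_*_v3 with store_false and an explicit dest keeps the parsed attribute names unchanged while flipping their default to True, matching the new kernel defaults. A small standalone sketch of that argparse pattern (not taken from the benchmark script itself):

    import argparse

    parser = argparse.ArgumentParser()
    # store_false + dest: the attribute stays `use_ck_bwd_v3`, defaults to True,
    # and only passing --no_ck_bwd_v3 turns it off.
    parser.add_argument("--no_ck_bwd_v3", action="store_false", dest="use_ck_bwd_v3")
    parser.add_argument("--no_ck_fwd_v3", action="store_false", dest="use_ck_fwd_v3")

    args = parser.parse_args(["--no_ck_bwd_v3"])
    print(args.use_ck_bwd_v3, args.use_ck_fwd_v3)  # False True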

ci/jax.sh

Lines changed: 7 additions & 1 deletion
@@ -50,12 +50,18 @@ run_lbl() {
     _test_label=""
 }
 
+run_default_fa_lbl() {
+    if [ $_fus_attn = "$_DEFAULT_FUSED_ATTN" ]; then
+        run_lbl "$@"
+    fi
+}
+
 run_test_config() {
     echo ==== Run with Fused attention backend: $_fus_attn ====
     run_default_fa 1 test_custom_call_compute.py
     run_default_fa 1 test_functions.py
     run 1 test_fused_attn.py
-    NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "v3" 1 test_fused_attn.py # Using FAv3 for forward and backward pass
+    NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
     run_default_fa 1 test_helper.py
     run_default_fa 1 test_layer.py #it effectevly always uses unfused attention
     run_default_fa 1 test_sanity_import.py

ci/pytorch.sh

Lines changed: 1 addition & 2 deletions
@@ -72,7 +72,6 @@ run_test_config(){
     run 1 test_sanity.py
     run_default_fa 1 test_sanity_import.py
     run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test
-    NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_default_fa_lbl "v3" 1 fused_attn/test_fused_attn.py # Using FAv3 for forward and backward pass
     run_default_fa 1 triton_kernels/test_cast.py
     run_default_fa 1 triton_kernels/test_cast_mxfp8.py
     run_default_fa 1 triton_kernels/test_norm_common.py
@@ -113,7 +112,7 @@ run_benchmark() {
         return
     fi
 
-    python "$BENCH_SCRIPT" --use_ck_fwd_v3 --use_ck_bwd_v3 --run_sanity_checks || test_run_error $BENCH_SCRIPT
+    python "$BENCH_SCRIPT" --run_sanity_checks || test_run_error $BENCH_SCRIPT
 }
 
 # Single config mode, run it and return result

tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 24 additions & 3 deletions
@@ -92,7 +92,8 @@ def __del__(self):
 @pytest.fixture(autouse=True)
 def reset_attn_backend():
     env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN",
-                         "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON"])
+                         "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON",
+                         "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"])
     yield
 
 
@@ -421,6 +422,8 @@ def test_dot_product_attention(
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
         os.environ["NVTE_FUSED_ATTN_CK"] = "1"
         os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
+        os.environ["NVTE_CK_USES_FWD_V3"] = "1"
+        os.environ["NVTE_CK_USES_BWD_V3"] = "1"
         fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
             dtype,
             config,
@@ -432,8 +435,21 @@ def test_dot_product_attention(
             is_training,
             share_cu_seqlens_ref,
         )
-        del os.environ["NVTE_FUSED_ATTN_CK"]
-        del os.environ["NVTE_FUSED_ATTN_AOTRITON"]
+        if IS_HIP_EXTENSION:
+            os.environ["NVTE_CK_USES_FWD_V3"] = "0"
+            os.environ["NVTE_CK_USES_BWD_V3"] = "0"
+            fused_attn_fwd_2, fused_attn_bwd_2 = _run_dot_product_attention(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                pad_between_seqs,
+                is_training,
+                share_cu_seqlens_ref,
+            )
+
 
     # FlashAttention backend
     if flash_attn_supported:
@@ -469,6 +485,11 @@ def test_dot_product_attention(
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
         for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
+        if IS_HIP_EXTENSION:
+            logging.info("[test_dot_product_attention]: fused attn backend 0 vs 2")
+            torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
+            for i, _ in enumerate(fused_attn_bwd):
+                torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
 
 
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
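
The test change above runs the CK fused-attention path twice on ROCm, first with the v3 kernels enabled and then with them disabled, and checks both results against the reference backend within the same tolerances. A condensed sketch of that toggle-and-compare idea, with a hypothetical run_attention() helper standing in for _run_dot_product_attention and its arguments:

    import os
    import torch

    def compare_v3_vs_v2(run_attention, tols):
        """run_attention() is assumed to return (fwd_output, bwd_grads)."""
        # First pass: AITER/CK v3 kernels enabled (the new default).
        os.environ["NVTE_CK_USES_FWD_V3"] = "1"
        os.environ["NVTE_CK_USES_BWD_V3"] = "1"
        fwd_v3, bwd_v3 = run_attention()

        # Second pass: fall back to the v2 kernels.
        os.environ["NVTE_CK_USES_FWD_V3"] = "0"
        os.environ["NVTE_CK_USES_BWD_V3"] = "0"
        fwd_v2, bwd_v2 = run_attention()

        # Both kernel generations are expected to agree within the usual tolerances.
        torch.testing.assert_close(fwd_v3, fwd_v2, **tols)
        for g_v3, g_v2 in zip(bwd_v3, bwd_v2):
            torch.testing.assert_close(g_v3, g_v2, **tols)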

transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp

Lines changed: 2 additions & 2 deletions
@@ -556,7 +556,7 @@ void fused_attn_ck_fwd_impl(
     if (env_p != nullptr && std::string(env_p) == "1")
       nvte_log_ck_config = true;
   }
-  bool nvte_ck_uses_fwd_v3 = getenv<int>("NVTE_CK_USES_FWD_V3", 0);
+  bool nvte_ck_uses_fwd_v3 = getenv<int>("NVTE_CK_USES_FWD_V3", 1);
 
   bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD;
 
@@ -1037,7 +1037,7 @@ void fused_attn_ck_bwd_impl(
 
   // bwd v3 is optional by enabling the following envs
   // default values follows the ck example setting
-  bool nvte_ck_uses_bwd_v3 = getenv<int>("NVTE_CK_USES_BWD_V3", 0);
+  bool nvte_ck_uses_bwd_v3 = getenv<int>("NVTE_CK_USES_BWD_V3", 1);
   bool nvte_ck_is_v3_atomic_fp32 = getenv<int>("NVTE_CK_IS_V3_ATOMIC_FP32", 1);
   int nvte_ck_how_v3_bf16_cvt = getenv<int>("NVTE_CK_HOW_V3_BF16_CVT", 1);
   if (nvte_log_ck_config) {
