diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 9b20836ff..d561fb887 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -196,16 +196,35 @@ def test():
     if IS_HIP_EXTENSION:
         backends = {"AOTriton": "AOTRITON", "CK": "CK"}
+        # Record each backend's original env setting; None marks a variable
+        # that was unset, so it can be removed again instead of left behind.
+        backend_enabled = {}
+        for i in backends.keys():
+            backend_enabled[backends[i]] = os.environ.get("NVTE_FUSED_ATTN_"+backends[i])
         with logging_context():
             for i in backends.keys():
+                # Skip a backend that the environment already disabled
+                # (an unset variable is treated as enabled).
+                enabled = backend_enabled[backends[i]]
+                if enabled is not None and int(enabled) == 0:
+                    continue
+                # Disable every backend, then re-enable only the one under test.
                 for k in backends.keys():
                     os.environ["NVTE_FUSED_ATTN_"+backends[k]] = "0"
                 os.environ["NVTE_FUSED_ATTN_"+backends[i]] = "1"
                 _attention_backends["backend_selection_requires_update"] = True
-                available_backends, fused_attention_backend = test()
+                try:
+                    available_backends, fused_attention_backend = test()
+                finally:
+                    # Restore the original env even if test() raises, so a
+                    # failure here cannot leak overrides into later tests.
+                    for k in backends.keys():
+                        name = "NVTE_FUSED_ATTN_"+backends[k]
+                        if backend_enabled[backends[k]] is None:
+                            os.environ.pop(name, None)
+                        else:
+                            os.environ[name] = backend_enabled[backends[k]]
                 if fused_attention_backend == FusedAttnBackend[i]:
                     fused_attn_backends.append(fused_attention_backend)
-            for i in backends.keys():
-                del os.environ["NVTE_FUSED_ATTN_"+backends[i]];
     else:
         backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
         with logging_context():