From 96a0158915258a34c6f61a77895f5b427e4ccb08 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Tue, 4 Nov 2025 15:40:44 -0600
Subject: [PATCH 1/3] Add AITER ASM-enabled distributed FA testing in jax/torch

---
 ci/jax.sh     | 1 +
 ci/pytorch.sh | 1 +
 2 files changed, 2 insertions(+)

diff --git a/ci/jax.sh b/ci/jax.sh
index cc080916c..c14446588 100755
--- a/ci/jax.sh
+++ b/ci/jax.sh
@@ -78,6 +78,7 @@ run_test_config_mgpu() {
         *)
             # Workaround for distributed tests hang with xla_flag
             XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
+            XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
             ;;
     esac
 
diff --git a/ci/pytorch.sh b/ci/pytorch.sh
index 207949ee5..164bc761d 100755
--- a/ci/pytorch.sh
+++ b/ci/pytorch.sh
@@ -94,6 +94,7 @@ run_test_config_mgpu(){
         run 3 distributed/test_numerics.py
         run 3 distributed/test_torch_fsdp2.py
         run 3 fused_attn/test_fused_attn_with_cp.py
+        NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 fused_attn/test_fused_attn_with_cp.py
     fi
 }
 

From 2999f9fb6edfdd1a1fe97f8969ccaeac7c59eb13 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 5 Nov 2025 14:41:39 -0600
Subject: [PATCH 2/3] Add V3 dist runs for older JAX versions

---
 ci/jax.sh | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/ci/jax.sh b/ci/jax.sh
index c14446588..2236baef1 100755
--- a/ci/jax.sh
+++ b/ci/jax.sh
@@ -71,9 +71,11 @@ run_test_config_mgpu() {
         *0.4.35*)
             # Workaround for distributed tests hang with xla_flag
             XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
+            XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
 
             # Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
             XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
+            XLA_FLAGS="--xla_experimental_ignore_channel_id" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
             ;;
         *)
             # Workaround for distributed tests hang with xla_flag

From 3b072fff86326d96c2e772dea0f86960abb2b337 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Mon, 10 Nov 2025 12:42:44 -0600
Subject: [PATCH 3/3] Guard CK V3 tests based on backend

---
 ci/jax.sh | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ci/jax.sh b/ci/jax.sh
index 2236baef1..2adc204b5 100755
--- a/ci/jax.sh
+++ b/ci/jax.sh
@@ -71,16 +71,16 @@ run_test_config_mgpu() {
         *0.4.35*)
             # Workaround for distributed tests hang with xla_flag
             XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
-            XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
+            test "$_fus_attn" = "ck" && XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
 
             # Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
             XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
-            XLA_FLAGS="--xla_experimental_ignore_channel_id" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
+            test "$_fus_attn" = "ck" && XLA_FLAGS="--xla_experimental_ignore_channel_id" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
             ;;
         *)
             # Workaround for distributed tests hang with xla_flag
             XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
-            XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
+            test "$_fus_attn" = "ck" && XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
             ;;
     esac
 
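Note on the guard pattern in PATCH 3/3: in shell, variable assignments written in front of a command (`VAR=1 cmd`) apply to that single invocation only, so the backend check must come first and the V3 switches must sit directly before `run`/`run_lbl`; if the assignments precede `test` instead, they bind to `test` and never reach the test run. A minimal, self-contained sketch of the pattern (the `run` stub and the sample `_fus_attn` values are illustrative stand-ins for the CI helpers, not the real ones):

  #!/usr/bin/env bash
  # Stand-in for the CI "run" helper: report which V3 flags it sees.
  run() {
      echo "run $*: FWD_V3=${NVTE_CK_USES_FWD_V3:-unset} BWD_V3=${NVTE_CK_USES_BWD_V3:-unset}"
  }

  # Backend under test; "ck" and "aotriton" are illustrative values here.
  for _fus_attn in ck aotriton; do
      # Guard first, then the one-shot env prefix, so the flags reach "run".
      # For a non-CK backend the guarded command is skipped entirely.
      test "$_fus_attn" = "ck" && NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
  done

Run as-is, this prints "run 3 test_distributed_fused_attn.py: FWD_V3=1 BWD_V3=1" for the "ck" iteration and nothing for the other, which is exactly the behavior the guarded CI lines rely on.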