diff --git a/ci/jax.sh b/ci/jax.sh
index cc080916c..2adc204b5 100755
--- a/ci/jax.sh
+++ b/ci/jax.sh
@@ -71,13 +71,16 @@ run_test_config_mgpu() {
         *0.4.35*)
             # Workaround for distributed tests hang with xla_flag
             XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
+            # Re-run with the CK V3 forward/backward kernels when the CK fused-attention backend is selected
+            test "$_fus_attn" = "ck" && XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
             # Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
             XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
+            # Re-run ring attention with the CK V3 kernels on the CK backend
+            test "$_fus_attn" = "ck" && XLA_FLAGS="--xla_experimental_ignore_channel_id" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
             ;;
         *)
             # Workaround for distributed tests hang with xla_flag
             XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
+            # Re-run with the CK V3 kernels when the CK fused-attention backend is selected
+            test "$_fus_attn" = "ck" && XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py
             ;;
     esac
diff --git a/ci/pytorch.sh b/ci/pytorch.sh
index 93b9ded7f..135b77c02 100755
--- a/ci/pytorch.sh
+++ b/ci/pytorch.sh
@@ -94,6 +94,7 @@ run_test_config_mgpu(){
         run 3 distributed/test_numerics.py
         run 3 distributed/test_torch_fsdp2.py
         run 3 fused_attn/test_fused_attn_with_cp.py
+        # Re-run the context-parallel fused attention tests with the CK V3 kernels enabled
+        NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 fused_attn/test_fused_attn_with_cp.py
     fi
 }