@@ -21,6 +21,12 @@ install_prerequisites() {
         script_error "Failed to install Flax and dependencies"
         return $rc
     fi
+    pip install pytest-timeout
+    rc=$?
+    if [ $rc -ne 0 ]; then
+        script_error "Failed to install test prerequisites"
+        exit $rc
+    fi
 }

 TEST_DIR=${TE_PATH}tests/jax
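For reference, the newly installed pytest-timeout plugin is what provides the `--timeout` and `--timeout-method` flags used in the second hunk. A minimal sketch of the resulting behavior, outside this diff (the test path is illustrative):

```bash
# With pytest-timeout installed, a test exceeding the limit is aborted and
# reported as a failure instead of hanging the CI job indefinitely.
# "--timeout-method thread" starts a watchdog timer thread that dumps the
# stack traces of all threads and then kills the test process; unlike the
# default signal-based method, it also fires when a test is stuck in
# native (non-Python) code.
pytest --timeout 300 --timeout-method thread tests/jax/test_distributed_fused_attn.py
```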
@@ -65,24 +71,15 @@ run_test_config() {

 run_test_config_mgpu () {
     echo ==== Run mGPU with Fused attention backend: $_fus_attn ====
-
-    _ver=$(pip show jaxlib | grep Version)
-    case "$_ver" in
-        *0.4.35*)
-            # Workaround for distributed tests hang with xla_flag
-            XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
-
-            # Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
-            XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
-            ;;
-        *)
-            # Workaround for distributed tests hang with xla_flag
-            XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
-            ;;
-    esac
-
+
+    # Mitigate distributed test hangs by enforcing a 5-minute per-test timeout
+    _timeout_args="--timeout 300 --timeout-method thread"
+    # Workaround for hangs/aborts in some distributed tests
+    export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false"
+
+    run 3 test_distributed_fused_attn.py $_timeout_args
     run_default_fa 3 test_distributed_layernorm.py
-    XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 3 test_distributed_layernorm_mlp.py
+    run_default_fa 3 test_distributed_layernorm_mlp.py $_timeout_args
     run_default_fa 3 test_distributed_softmax.py

     run_default_fa 3 test_sanity_import.py
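To make the net effect concrete, here is the approximate pytest invocation produced by the first changed line — a sketch that assumes the `run` helper defined earlier in this script forwards its trailing arguments to pytest (the `-v` flag and the path are illustrative):

```bash
# Approximate expansion of: run 3 test_distributed_fused_attn.py $_timeout_args
# Because XLA_FLAGS is now exported once at the top of the function, the
# NCCL comm-splitting workaround applies to every test below it, not only
# to the invocations that previously set it inline.
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" \
    pytest -v tests/jax/test_distributed_fused_attn.py \
        --timeout 300 --timeout-method thread
```

Note that the jaxlib 0.4.35 case split and the separate `--xla_experimental_ignore_channel_id` ring-attention run are removed entirely: all distributed fused-attention tests now run under the single exported flag, with the per-test timeout as the backstop against hangs.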