Skip to content

Commit 08bb25e

Browse files
authored
Add pytest timeout to mitigate JAX tests hang (#395)
1 parent 9aa2101 commit 08bb25e

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

ci/jax.sh

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ install_prerequisites() {
2121
script_error "Failed to install Flax and dependencies"
2222
return $rc
2323
fi
24+
pip install pytest-timeout
25+
rc=$?
26+
if [ $rc -ne 0 ]; then
27+
script_error "Failed to install test prerequisites"
28+
exit $rc
29+
fi
2430
}
2531

2632
TEST_DIR=${TE_PATH}tests/jax
@@ -65,24 +71,15 @@ run_test_config() {
6571

6672
run_test_config_mgpu() {
6773
echo ==== Run mGPU with Fused attention backend: $_fus_attn ====
68-
69-
_ver=$(pip show jaxlib | grep Version)
70-
case "$_ver" in
71-
*0.4.35*)
72-
# Workaround for distributed tests hang with xla_flag
73-
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
74-
75-
# Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
76-
XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
77-
;;
78-
*)
79-
# Workaround for distributed tests hang with xla_flag
80-
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
81-
;;
82-
esac
83-
74+
75+
# Mitigate distributed tests hang by adding 5min timeout
76+
_timeout_args="--timeout 300 --timeout-method thread"
77+
# Workaround for some distributed tests hang/abotrion
78+
export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false"
79+
80+
run 3 test_distributed_fused_attn.py $_timeout_args
8481
run_default_fa 3 test_distributed_layernorm.py
85-
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 3 test_distributed_layernorm_mlp.py
82+
run_default_fa 3 test_distributed_layernorm_mlp.py $_timeout_args
8683
run_default_fa 3 test_distributed_softmax.py
8784

8885
run_default_fa 3 test_sanity_import.py

0 commit comments

Comments
 (0)