Skip to content

Commit 50e896e

Browse files
authored
fix: Prevent DEADLINE_EXCEEDED errors in multi-slice TPU jobs (#1166)
Multi-slice TPU tasks were failing with `DEADLINE_EXCEEDED` and `Heartbeat` timeouts. Log analysis showed an observed 327s compilation window which exceeded the default JAX RPC timeout, causing nodes to be dropped before the training loop started.
1 parent d0d00f7 commit 50e896e

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

dags/sparsity_diffusion_devx/jax_ai_image_tpu_e2e.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
cluster=cluster,
108108
time_out_in_min=60,
109109
run_model_cmds=(
110+
"export JAX_COORDINATION_SERVICE_HEARTBEAT_TIMEOUT_SECONDS=1200 "
111+
"JAX_ENABLE_COMPILATION_CACHE=false "
110112
f"JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && "
111113
f"pip install . && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml "
112114
f"pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 "

0 commit comments

Comments
 (0)