diff --git a/dags/sparsity_diffusion_devx/jax_ai_image_tpu_e2e.py b/dags/sparsity_diffusion_devx/jax_ai_image_tpu_e2e.py index 7609d2c6b..525e32676 100644 --- a/dags/sparsity_diffusion_devx/jax_ai_image_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/jax_ai_image_tpu_e2e.py @@ -86,7 +86,7 @@ time_out_in_min=60, run_model_cmds=( f"JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && " - f"python -m MaxText.train MaxText/configs/base.yml run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} " + f"python -m MaxText.train maxtext/configs/base.yml run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} " "steps=30 per_device_batch_size=1 max_target_length=4096 model_name=llama2-7b " "enable_checkpointing=false attention=dot_product remat_policy=minimal_flash use_iota_embed=true scan_layers=false " "dataset_type=synthetic async_checkpointing=false "