diff --git a/dags/common/test_owner.py b/dags/common/test_owner.py index 4789e4a4c..02a98cdb2 100644 --- a/dags/common/test_owner.py +++ b/dags/common/test_owner.py @@ -65,6 +65,7 @@ class Team(enum.Enum): KUNJAN_P = "Kunjan P." MICHELLE_Y = "Michelle Y." SHUNING_J = "Shuning J." +ROHAN_B = "Rohan B." # Inference ANDY_Y = "Andy Y." diff --git a/dags/common/vm_resource.py b/dags/common/vm_resource.py index 885fdbfa8..8e48a98d2 100644 --- a/dags/common/vm_resource.py +++ b/dags/common/vm_resource.py @@ -361,8 +361,7 @@ class DockerImage(enum.Enum): f"{datetime.datetime.today().strftime('%Y-%m-%d')}" ) MAXDIFFUSION_TPU_JAX_STABLE_STACK_CANDIDATE = ( - "gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:" - f"{datetime.datetime.today().strftime('%Y-%m-%d')}" + "gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest" ) MAXTEXT_TPU_JAX_NIGHTLY = ( "gcr.io/tpu-prod-env-multipod/maxtext_jax_nightly:" diff --git a/dags/multipod/legacy.py b/dags/multipod/legacy.py index ce08799f8..9d96d4bcf 100644 --- a/dags/multipod/legacy.py +++ b/dags/multipod/legacy.py @@ -25,7 +25,7 @@ # Run once a day at 9 am UTC (1 am PST) SCHEDULED_TIME = "0 9 * * *" if composer_env.is_prod_env() else None DOCKER_IMAGE = { - SetupMode.STABLE: DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE, + SetupMode.STABLE: DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK, SetupMode.NIGHTLY: DockerImage.MAXTEXT_TPU_JAX_NIGHTLY, } diff --git a/dags/multipod/maxtext_checkpointing.py b/dags/multipod/maxtext_checkpointing.py index 2a3c4c929..fe07c5486 100644 --- a/dags/multipod/maxtext_checkpointing.py +++ b/dags/multipod/maxtext_checkpointing.py @@ -45,7 +45,7 @@ current_time = datetime.datetime.now() current_datetime = current_time.strftime("%Y-%m-%d-%H-%M-%S") docker_images = [ - (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE), + (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK), (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY), ] test_configs = { 
diff --git a/dags/multipod/maxtext_convergence.py b/dags/multipod/maxtext_convergence.py index 17d46c102..1186ab8f2 100644 --- a/dags/multipod/maxtext_convergence.py +++ b/dags/multipod/maxtext_convergence.py @@ -76,7 +76,7 @@ time_out_in_min=300, test_name=test_name, run_model_cmds=run_command, - docker_image=DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE.value, + docker_image=DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK.value, test_owner=test_owner.MATT_D, base_output_directory=base_output_directory, metric_aggregation_strategy=metric_config.AggregationStrategy.LAST, diff --git a/dags/multipod/maxtext_end_to_end.py b/dags/multipod/maxtext_end_to_end.py index 9ba95dca0..a21ec3aa6 100644 --- a/dags/multipod/maxtext_end_to_end.py +++ b/dags/multipod/maxtext_end_to_end.py @@ -63,7 +63,7 @@ f"export HF_TOKEN={HF_TOKEN}", f"bash end_to_end/{test_script}.sh", ), - docker_image=DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE.value, + docker_image=DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK.value, cluster=XpkClusters.TPU_V5P_8_CLUSTER, test_owner=test_owner.MOHIT_K, ).run_with_quarantine(quarantine_task_group) @@ -146,7 +146,7 @@ def convert_checkpoint_and_run_training( return conversion_cpu, training_tpu docker_image = { - "stable": DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE.value, + "stable": DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK.value, "nightly": DockerImage.MAXTEXT_TPU_JAX_NIGHTLY.value, } tests = [] diff --git a/dags/multipod/maxtext_sft_trainer.py b/dags/multipod/maxtext_sft_trainer.py index a12e3388c..31fa7dc5e 100644 --- a/dags/multipod/maxtext_sft_trainer.py +++ b/dags/multipod/maxtext_sft_trainer.py @@ -36,7 +36,7 @@ ) as dag: base_output_directory = f'{gcs_bucket.BASE_OUTPUT_DIR}/maxtext_sft_trainer' docker_images = [ - (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE), + (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK), (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY), ] diff --git 
a/dags/multipod/maxtext_trillium_configs_perf.py b/dags/multipod/maxtext_trillium_configs_perf.py index e0dde2c43..e0d241ad0 100644 --- a/dags/multipod/maxtext_trillium_configs_perf.py +++ b/dags/multipod/maxtext_trillium_configs_perf.py @@ -77,7 +77,7 @@ model in need_stable_candidate_set and image == DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK ): - image = DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE + image = DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK base_run_model_cmds = [ f"python3 -m benchmarks.benchmark_runner on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
diff --git a/dags/sparsity_diffusion_devx/jax_ai_image_candidate_e2e.py b/dags/sparsity_diffusion_devx/jax_ai_image_candidate_e2e.py
new file mode 100644
index 000000000..6f283feaa
--- /dev/null
+++ b/dags/sparsity_diffusion_devx/jax_ai_image_candidate_e2e.py
@@ -0,0 +1,126 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A DAG to run end-to-end JAX AI Image Candidate TPU tests before public release."""
+
+
+import datetime
+from airflow import models
+from airflow.utils.task_group import TaskGroup
+from dags import composer_env, gcs_bucket
+from dags.common import test_owner
+from dags.common.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters
+from dags.sparsity_diffusion_devx.configs import gke_config as config
+from dags.multipod.configs.common import SetupMode
+from xlml.utils import name_format
+
+
+# NOTE(review): no `schedule` is passed here, so Airflow's default schedule
+# applies. Other DAGs in this repo (e.g. dags/multipod/legacy.py) only schedule
+# in prod via composer_env.is_prod_env(); confirm whether this pre-release DAG
+# should be manual-trigger only (schedule=None).
+with models.DAG(
+    dag_id="jax_ai_image_candidate_tpu_e2e",
+    tags=[
+        "sparsity_diffusion_devx",
+        "multipod_team",
+        "maxtext",
+        "maxdiffusion",
+        "tpu",
+        "jax-stable-stack",
+        "mlscale_devx",
+    ],
+    start_date=datetime.datetime(2025, 7, 24),
+    catchup=False,
+) as dag:
+  current_datetime = config.get_current_datetime()
+  maxtext_test_configs = {
+      # accelerator: list of slices to test
+      "v4-16": [1],
+      "v5-8": [1, 2],
+  }
+  maxdiffusion_test_configs = {
+      # accelerator: list of slices to test
+      "v5-8": [1, 2],
+  }
+
+  quarantine_task_group = TaskGroup(
+      group_id="Quarantine", dag=dag, prefix_group_id=False
+  )
+
+  maxtext_docker_images = [
+      (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE)
+  ]
+
+  maxdiffusion_docker_images = [
+      (
+          SetupMode.STABLE,
+          DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK_CANDIDATE,
+      )
+  ]
+
+  # The TPU runtime flags must be exported: a bare `VAR=value ... && cmd` only
+  # sets shell-local variables that the training process would not inherit.
+  tpu_env_exports = (
+      "export JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true "
+      "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true "
+      "TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true "
+      "ENABLE_TPUNETD_CLIENT=true && "
+  )
+
+  for accelerator, slices in maxtext_test_configs.items():
+    cores = accelerator.rsplit("-", maxsplit=1)[-1]
+    cluster = config.clusters[accelerator]
+    for slice_num in slices:
+      for mode, image in maxtext_docker_images:
+        maxtext_jax_stable_stack_test = config.get_gke_config(
+            num_slices=slice_num,
+            cluster=cluster,
+            time_out_in_min=60,
+            run_model_cmds=(
+                f"{tpu_env_exports}"
+                f"python -m MaxText.train MaxText/configs/base.yml run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} "
+                "steps=30 per_device_batch_size=1 max_target_length=4096 model_name=llama2-7b "
+                "enable_checkpointing=false attention=dot_product remat_policy=minimal_flash use_iota_embed=true scan_layers=false "
+                "dataset_type=synthetic async_checkpointing=false "
+                f"base_output_directory={gcs_bucket.BASE_OUTPUT_DIR}/maxtext/jax-stable-stack/automated/{current_datetime}",
+            ),
+            test_name=f"maxtext-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}x",
+            docker_image=image.value,
+            test_owner=test_owner.PARAM_B,
+        ).run_with_quarantine(quarantine_task_group)
+
+  for accelerator, slices in maxdiffusion_test_configs.items():
+    cores = accelerator.rsplit("-", maxsplit=1)[-1]
+    cluster = config.clusters[accelerator]
+    for slice_num in slices:
+      for mode, image in maxdiffusion_docker_images:
+        maxdiffusion_jax_stable_stack_test = config.get_gke_config(
+            num_slices=slice_num,
+            cluster=cluster,
+            time_out_in_min=60,
+            run_model_cmds=(
+                f"{tpu_env_exports}"
+                "pip install . && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml "
+                "pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 "
+                "revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 "
+                "dataset_name=jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl resolution=1024 per_device_batch_size=1 "
+                "jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ max_train_steps=20 attention=flash enable_profiler=True "
+                f"run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxdiffusion-jax-stable-stack-{current_datetime} "
+                f"output_dir={gcs_bucket.BASE_OUTPUT_DIR}/maxdiffusion-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}/automated/{current_datetime}",
+            ),
+            test_name=f"maxdiffusion-jax-stable-stack-sdxl-{mode.value}-{accelerator}-{slice_num}x",
+            docker_image=image.value,
+            test_owner=test_owner.PARAM_B,
+        ).run_with_quarantine(quarantine_task_group)
diff --git a/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py b/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py index a31f6143a..467fa0a75 100644 --- a/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py @@ -66,7 +66,7 @@ ) maxtext_docker_images = [ - (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE), + (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK), (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_STABLE_STACK_NIGHTLY_JAX), ] diff --git a/dags/sparsity_diffusion_devx/maxdiffusion_tpu_e2e.py b/dags/sparsity_diffusion_devx/maxdiffusion_tpu_e2e.py index 913d1e7e7..7b38f0358 100644 --- a/dags/sparsity_diffusion_devx/maxdiffusion_tpu_e2e.py +++ b/dags/sparsity_diffusion_devx/maxdiffusion_tpu_e2e.py @@ -111,7 +111,7 @@ f"output_dir={sdxl_base_output_dir}", ), test_name=sdxl_run_name_prefix, - 
docker_image=DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK.value, + docker_image=DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK_CANDIDATE.value, test_owner=test_owner.PARAM_B, tensorboard_summary_config=sdxl_tensorboard_summary_config, ).run_with_name_gen_and_quarantine( @@ -133,7 +133,7 @@ f"LOSS_THRESHOLD=100", ), test_name=sdxl_nan_run_name_prefix, - docker_image=DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK.value, + docker_image=DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK_CANDIDATE.value, test_owner=test_owner.PARAM_B, tensorboard_summary_config=sdxl_nan_tensorboard_summary_config, ).run_with_name_gen_and_quarantine( @@ -166,7 +166,7 @@ f"output_dir={sdv2_base_output_dir}", ), test_name=f"maxd-sdv2-{accelerator}-{slice_num}x", - docker_image=DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK.value, + docker_image=DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK_CANDIDATE.value, test_owner=test_owner.PARAM_B, ).run_with_quarantine(quarantine_task_group)