-
Notifications
You must be signed in to change notification settings - Fork 62
Expand file tree
/
Copy pathjax_stable_stack_tpu_e2e.py
More file actions
144 lines (132 loc) · 6.22 KB
/
jax_stable_stack_tpu_e2e.py
File metadata and controls
144 lines (132 loc) · 6.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A DAG to run end-to-end JAX Stable Stack TPU tests."""
import datetime
from airflow import models
from airflow.utils.task_group import TaskGroup
from dags import composer_env, gcs_bucket
from dags.common import test_owner
from dags.common.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters
from dags.sparsity_diffusion_devx.configs import gke_config as config
from dags.multipod.configs.common import SetupMode
from xlml.utils import name_format
# Cron schedule for the DAG: daily at 03:00 UTC (7 pm PST) when running in the
# production Composer environment; every other environment gets no schedule.
SCHEDULED_TIME = None
if composer_env.is_prod_env():
  SCHEDULED_TIME = "0 3 * * *"
with models.DAG(
    dag_id="jax_stable_stack_tpu_e2e",
    schedule=SCHEDULED_TIME,
    tags=[
        "sparsity_diffusion_devx",
        "multipod_team",
        "maxtext",
        "maxdiffusion",
        "axlearn",
        "tpu",
        "jax-stable-stack",
        "mlscale_devx",
    ],
    start_date=datetime.datetime(2024, 6, 7),
    catchup=False,
) as dag:
  # Timestamp computed once so that every run name and output path created
  # below shares the same suffix.
  current_datetime = config.get_current_datetime()

  # Environment assignments placed in front of every training command below.
  # NOTE(review): in a POSIX shell, `VAR=value ... && cmd` assigns shell
  # variables but does not export them into the environment of `cmd` unless
  # they are already exported -- confirm how `run_model_cmds` is executed and
  # whether these settings actually reach the training process.
  tpu_env_setup = (
      "JAX_PLATFORMS=tpu,cpu "
      "ENABLE_PJRT_COMPATIBILITY=true "
      "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true "
      "TPU_SLICE_BUILDER_DUMP_ICI=true "
      "JAX_FORCE_TPU_INIT=true "
      "ENABLE_TPUNETD_CLIENT=true && "
  )

  maxtext_test_configs = {
      # accelerator: list of slices to test
      "v4-16": [1, 2],
      "v6e-256": [1],
  }
  maxdiffusion_test_configs = {
      # accelerator: list of slices to test
      "v4-8": [1],
      "v6e-256": [1],
  }
  axlearn_test_configs = {
      # accelerator: list of slices to test
      "v4-16": [1],
  }

  # Task group handed to every `run_with_quarantine` call below.
  quarantine_task_group = TaskGroup(
      group_id="Quarantine", dag=dag, prefix_group_id=False
  )

  # (setup mode, docker image) pairs; each test configuration is instantiated
  # once per pair.
  maxtext_docker_images = [
      (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
      (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_STABLE_STACK_NIGHTLY_JAX),
  ]
  maxdiffusion_docker_images = [
      (
          SetupMode.STABLE,
          DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK,
      ),
      (
          SetupMode.NIGHTLY,
          DockerImage.MAXDIFFUSION_TPU_STABLE_STACK_NIGHTLY_JAX,
      ),
  ]

  # MaxText: one test per (accelerator, slice count, setup mode).
  for accelerator, slices in maxtext_test_configs.items():
    # Trailing component of the accelerator name, e.g. "16" for "v4-16".
    cores = accelerator.rsplit("-", maxsplit=1)[-1]
    cluster = config.clusters[accelerator]
    for slice_num in slices:
      for mode, image in maxtext_docker_images:
        # NOTE(review): run_name and base_output_directory do not include
        # `mode`, so the stable and nightly variants of the same
        # accelerator/slice combination share both -- confirm this is intended.
        maxtext_train_cmd = (
            f"{tpu_env_setup}"
            f"python -m MaxText.train MaxText/configs/base.yml run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} "
            "steps=30 per_device_batch_size=1 max_target_length=4096 model_name=llama2-7b "
            "enable_checkpointing=false attention=dot_product remat_policy=minimal_flash use_iota_embed=true scan_layers=false "
            "dataset_type=synthetic async_checkpointing=false "
            f"base_output_directory={gcs_bucket.BASE_OUTPUT_DIR}/maxtext/jax-stable-stack/automated/{current_datetime}"
        )
        maxtext_jax_stable_stack_test = config.get_gke_config(
            num_slices=slice_num,
            cluster=cluster,
            time_out_in_min=60,
            run_model_cmds=(maxtext_train_cmd,),
            test_name=f"maxtext-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}x",
            docker_image=image.value,
            test_owner=test_owner.PARAM_B,
        ).run_with_quarantine(quarantine_task_group)

  # MaxDiffusion: one test per (accelerator, slice count, setup mode).
  for accelerator, slices in maxdiffusion_test_configs.items():
    # Trailing component of the accelerator name, e.g. "8" for "v4-8".
    cores = accelerator.rsplit("-", maxsplit=1)[-1]
    cluster = config.clusters[accelerator]
    for slice_num in slices:
      for mode, image in maxdiffusion_docker_images:
        maxdiffusion_train_cmd = (
            f"{tpu_env_setup}"
            "pip install . && python src/maxdiffusion/train.py src/maxdiffusion/configs/base_2_base.yml "
            f"run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxdiffusion-jax-stable-stack-{current_datetime} "
            f"output_dir={gcs_bucket.BASE_OUTPUT_DIR}/maxdiffusion-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}/automated/{current_datetime}"
        )
        maxdiffusion_jax_stable_stack_test = config.get_gke_config(
            num_slices=slice_num,
            cluster=cluster,
            time_out_in_min=60,
            run_model_cmds=(maxdiffusion_train_cmd,),
            test_name=f"maxdiffusion-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}x",
            # Use the image paired with `mode`; previously the stable image
            # was hard-coded here, so the test named "nightly" never ran the
            # nightly image.
            docker_image=image.value,
            test_owner=test_owner.PARAM_B,
        ).run_with_quarantine(quarantine_task_group)

  # AXLearn: a single docker image, one test per (accelerator, slice count).
  for accelerator, slices in axlearn_test_configs.items():
    cluster = config.clusters[accelerator]
    for slice_num in slices:
      axlearn_train_cmd = (
          f"{tpu_env_setup}"
          "cd axlearn && python -m axlearn.common.launch_trainer_main "
          "--module=text.gpt.c4_trainer --config=fuji-test-v1 "
          f"--trainer_dir={gcs_bucket.BASE_OUTPUT_DIR}/bite/jax-stable-stack/automated/{current_datetime} "
          f"--data_dir={gcs_bucket.AXLEARN_DIR} --jax_backend=tpu "
      )
      axlearn_jax_stable_stack_test = config.get_gke_config(
          num_slices=slice_num,
          cluster=cluster,
          time_out_in_min=300,
          run_model_cmds=(axlearn_train_cmd,),
          test_name=f"axlearn-jax-stable-stack-{accelerator}-{slice_num}x",
          docker_image=DockerImage.AXLEARN_TPU_JAX_STABLE_STACK.value,
          test_owner=test_owner.PARAM_B,
      ).run_with_quarantine(quarantine_task_group)