Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions dags/multipod/maxtext_trillium_configs_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from airflow.utils.task_group import TaskGroup
from dags import composer_env
from dags.common import test_owner
from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage
from dags.common.vm_resource import Project, XpkClusters, DockerImage
from dags.common.model_configs import MaxTextTrilliumModelConfigs
from dags.multipod.configs import maxtext_sweep_gke_config
from dags.multipod.configs.common import SetupMode
from xlml.apis import metric_config, mlcompass
from xlml.apis import metric_config, mlcompass, task

# Run once a day at 3 am UTC (7 pm PST / 8 pm PDT)
CONIFGS_SCHEDULED_TIME = "0 9 * * *" if composer_env.is_prod_env() else None
Expand Down Expand Up @@ -66,7 +66,7 @@
quarantine_task_group = TaskGroup(
group_id="Quarantine", dag=dag, prefix_group_id=False
)
all_tests = []
all_tests: list[task.XpkTask] = []
for mode, image in DOCKER_IMAGES:
for model in MaxTextTrilliumModelConfigs:
# No tpu-recipe for DeepSeek v3
Expand All @@ -82,7 +82,11 @@
image = DockerImage.MAXTEXT_TPU_JAX_STABLE

base_run_model_cmds = [
f"python3 -m benchmarks.benchmark_runner on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
"bash preflight.sh",
f"python3 -m benchmarks.benchmark_runner on-device "
f"--base_output_directory={BASE_OUTPUT_DIRECTORY} "
f"--model_name={model.value} --libtpu_type=maxtext-docker "
f"--num_steps=15",
]
num_slices = (
[2]
Expand Down Expand Up @@ -110,16 +114,16 @@
enable_profile_config=enable_profile_config,
)
)
all_tests += maxtext_sweep_gke_test
all_tests.extend(maxtext_sweep_gke_test)

# Add dependencies between the tests so they are not all launched at once
mlcompass_scheduler = mlcompass.Scheduler()
chain_num = 4
chain_num = 16
prev = all_tests[0].run_with_name_gen_and_quarantine(quarantine_task_group)
mlcompass_scheduler.register(prev)
for i in range(1, len(all_tests)):
curr = all_tests[i].run_with_name_gen_and_quarantine(quarantine_task_group)
mlcompass_scheduler.register(curr)
if i % chain_num != 0:
prev >> curr
_ = prev >> curr
prev = curr
22 changes: 10 additions & 12 deletions dags/multipod/maxtext_v5e_configs_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from dags.common.model_configs import MaxTextV5eModelConfigs
from dags.multipod.configs import maxtext_sweep_gke_config
from dags.multipod.configs.common import SetupMode
from xlml.apis import metric_config
from xlml.apis import metric_config, task

# Run once a day at 4 am UTC (8 pm PST / 9 pm PDT)
SCHEDULED_TIME = "0 3 * * *" if composer_env.is_prod_env() else None
Expand Down Expand Up @@ -53,6 +53,7 @@
quarantine_task_group = TaskGroup(
group_id="Quarantine", dag=dag, prefix_group_id=False
)
all_tests: list[task.XpkTask] = []
for mode, image in DOCKER_IMAGES:
for model in MaxTextV5eModelConfigs:
base_run_model_cmds = [
Expand All @@ -78,18 +79,15 @@
sweep_params=QUANTIZATION_SWEEP,
)
)
all_tests.extend(maxtext_sweep_gke_test)

chain_num = 16
prev = maxtext_sweep_gke_test[0].run_with_name_gen_and_quarantine(
quarantine_task_group
)
for i in range(1, len(maxtext_sweep_gke_test)):
curr = maxtext_sweep_gke_test[i].run_with_name_gen_and_quarantine(
quarantine_task_group
)
if i % chain_num != 0:
_ = prev >> curr
prev = curr
chain_num = 16
prev = all_tests[0].run_with_name_gen_and_quarantine(quarantine_task_group)
for i in range(1, len(all_tests)):
curr = all_tests[i].run_with_name_gen_and_quarantine(quarantine_task_group)
if i % chain_num != 0:
_ = prev >> curr
prev = curr


# Run once a day at 10 am UTC (2 am PST / 3 am PDT)
Expand Down
27 changes: 14 additions & 13 deletions dags/multipod/maxtext_v5p_configs_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from dags.common.model_configs import MaxTextV5pModelConfigs
from dags.multipod.configs import maxtext_sweep_gke_config
from dags.multipod.configs.common import SetupMode
from xlml.apis import metric_config
from xlml.apis import metric_config, task

# Run once a day at 4 am UTC (8 pm PST / 9 pm PDT)
SCHEDULED_TIME = "0 1 * * *" if composer_env.is_prod_env() else None
Expand Down Expand Up @@ -53,11 +53,15 @@
quarantine_task_group = TaskGroup(
group_id="Quarantine", dag=dag, prefix_group_id=False
)
all_tests: list[task.XpkTask] = []
for mode, image in DOCKER_IMAGES:
for model in MaxTextV5pModelConfigs:
base_run_model_cmds = [
"bash preflight.sh",
f"python3 -m benchmarks.benchmark_runner on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
f"python3 -m benchmarks.benchmark_runner on-device "
f"--base_output_directory={BASE_OUTPUT_DIRECTORY} "
f"--model_name={model.value} --libtpu_type=maxtext-docker "
f"--num_steps=15",
]
maxtext_sweep_gke_test = (
maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
Expand All @@ -75,15 +79,12 @@
sweep_params=QUANTIZATION_SWEEP,
)
)
all_tests.extend(maxtext_sweep_gke_test)

chain_num = 4
prev = maxtext_sweep_gke_test[0].run_with_name_gen_and_quarantine(
quarantine_task_group
)
for i in range(1, len(maxtext_sweep_gke_test)):
curr = maxtext_sweep_gke_test[i].run_with_name_gen_and_quarantine(
quarantine_task_group
)
if i % chain_num != 0:
prev >> curr
prev = curr
chain_num = 16
prev = all_tests[0].run_with_name_gen_and_quarantine(quarantine_task_group)
for i in range(1, len(all_tests)):
curr = all_tests[i].run_with_name_gen_and_quarantine(quarantine_task_group)
if i % chain_num != 0:
_ = prev >> curr
prev = curr
Loading