Skip to content

Commit 7769c5a

Browse files
authored
fix: Adjust tasks concurrency to reduce parallelism (#1174)
This change reduce the test-level concurrency from 16 parallel tests to 4 parallel tests to avoid resource contention.
1 parent cab29a7 commit 7769c5a

File tree

3 files changed

+35
-32
lines changed

3 files changed

+35
-32
lines changed

dags/multipod/maxtext_trillium_configs_perf.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
from airflow.utils.task_group import TaskGroup
2121
from dags import composer_env
2222
from dags.common import test_owner
23-
from dags.common.vm_resource import TpuVersion, Zone, Project, XpkClusters, DockerImage
23+
from dags.common.vm_resource import Project, XpkClusters, DockerImage
2424
from dags.common.model_configs import MaxTextTrilliumModelConfigs
2525
from dags.multipod.configs import maxtext_sweep_gke_config
2626
from dags.multipod.configs.common import SetupMode
27-
from xlml.apis import metric_config, mlcompass
27+
from xlml.apis import metric_config, mlcompass, task
2828

2929
# Run once a day at 3 am UTC (7 pm PST / 8 pm PDT)
3030
CONIFGS_SCHEDULED_TIME = "0 9 * * *" if composer_env.is_prod_env() else None
@@ -66,7 +66,7 @@
6666
quarantine_task_group = TaskGroup(
6767
group_id="Quarantine", dag=dag, prefix_group_id=False
6868
)
69-
all_tests = []
69+
all_tests: list[task.XpkTask] = []
7070
for mode, image in DOCKER_IMAGES:
7171
for model in MaxTextTrilliumModelConfigs:
7272
# No tpu-recipe for DeepSeek v3
@@ -82,7 +82,11 @@
8282
image = DockerImage.MAXTEXT_TPU_JAX_STABLE
8383

8484
base_run_model_cmds = [
85-
f"python3 -m benchmarks.benchmark_runner on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
85+
"bash preflight.sh",
86+
f"python3 -m benchmarks.benchmark_runner on-device "
87+
f"--base_output_directory={BASE_OUTPUT_DIRECTORY} "
88+
f"--model_name={model.value} --libtpu_type=maxtext-docker "
89+
f"--num_steps=15",
8690
]
8791
num_slices = (
8892
[2]
@@ -110,16 +114,16 @@
110114
enable_profile_config=enable_profile_config,
111115
)
112116
)
113-
all_tests += maxtext_sweep_gke_test
117+
all_tests.extend(maxtext_sweep_gke_test)
114118

115119
# Add dependencies between the tests so they are not all launched at once
116120
mlcompass_scheduler = mlcompass.Scheduler()
117-
chain_num = 4
121+
chain_num = 16
118122
prev = all_tests[0].run_with_name_gen_and_quarantine(quarantine_task_group)
119123
mlcompass_scheduler.register(prev)
120124
for i in range(1, len(all_tests)):
121125
curr = all_tests[i].run_with_name_gen_and_quarantine(quarantine_task_group)
122126
mlcompass_scheduler.register(curr)
123127
if i % chain_num != 0:
124-
prev >> curr
128+
_ = prev >> curr
125129
prev = curr

dags/multipod/maxtext_v5e_configs_perf.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from dags.common.model_configs import MaxTextV5eModelConfigs
2525
from dags.multipod.configs import maxtext_sweep_gke_config
2626
from dags.multipod.configs.common import SetupMode
27-
from xlml.apis import metric_config
27+
from xlml.apis import metric_config, task
2828

2929
# Run once a day at 4 am UTC (8 pm PST / 9 pm PDT)
3030
SCHEDULED_TIME = "0 3 * * *" if composer_env.is_prod_env() else None
@@ -53,6 +53,7 @@
5353
quarantine_task_group = TaskGroup(
5454
group_id="Quarantine", dag=dag, prefix_group_id=False
5555
)
56+
all_tests: list[task.XpkTask] = []
5657
for mode, image in DOCKER_IMAGES:
5758
for model in MaxTextV5eModelConfigs:
5859
base_run_model_cmds = [
@@ -78,18 +79,15 @@
7879
sweep_params=QUANTIZATION_SWEEP,
7980
)
8081
)
82+
all_tests.extend(maxtext_sweep_gke_test)
8183

82-
chain_num = 16
83-
prev = maxtext_sweep_gke_test[0].run_with_name_gen_and_quarantine(
84-
quarantine_task_group
85-
)
86-
for i in range(1, len(maxtext_sweep_gke_test)):
87-
curr = maxtext_sweep_gke_test[i].run_with_name_gen_and_quarantine(
88-
quarantine_task_group
89-
)
90-
if i % chain_num != 0:
91-
_ = prev >> curr
92-
prev = curr
84+
chain_num = 16
85+
prev = all_tests[0].run_with_name_gen_and_quarantine(quarantine_task_group)
86+
for i in range(1, len(all_tests)):
87+
curr = all_tests[i].run_with_name_gen_and_quarantine(quarantine_task_group)
88+
if i % chain_num != 0:
89+
_ = prev >> curr
90+
prev = curr
9391

9492

9593
# Run once a day at 10 am UTC (2 am PST / 3 am PDT)

dags/multipod/maxtext_v5p_configs_perf.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from dags.common.model_configs import MaxTextV5pModelConfigs
2525
from dags.multipod.configs import maxtext_sweep_gke_config
2626
from dags.multipod.configs.common import SetupMode
27-
from xlml.apis import metric_config
27+
from xlml.apis import metric_config, task
2828

2929
# Run once a day at 4 am UTC (8 pm PST / 9 pm PDT)
3030
SCHEDULED_TIME = "0 1 * * *" if composer_env.is_prod_env() else None
@@ -53,11 +53,15 @@
5353
quarantine_task_group = TaskGroup(
5454
group_id="Quarantine", dag=dag, prefix_group_id=False
5555
)
56+
all_tests: list[task.XpkTask] = []
5657
for mode, image in DOCKER_IMAGES:
5758
for model in MaxTextV5pModelConfigs:
5859
base_run_model_cmds = [
5960
"bash preflight.sh",
60-
f"python3 -m benchmarks.benchmark_runner on-device --base_output_directory={BASE_OUTPUT_DIRECTORY} --model_name={model.value} --libtpu_type=maxtext-docker --num_steps=15",
61+
f"python3 -m benchmarks.benchmark_runner on-device "
62+
f"--base_output_directory={BASE_OUTPUT_DIRECTORY} "
63+
f"--model_name={model.value} --libtpu_type=maxtext-docker "
64+
f"--num_steps=15",
6165
]
6266
maxtext_sweep_gke_test = (
6367
maxtext_sweep_gke_config.get_maxtext_sweep_gke_config(
@@ -75,15 +79,12 @@
7579
sweep_params=QUANTIZATION_SWEEP,
7680
)
7781
)
82+
all_tests.extend(maxtext_sweep_gke_test)
7883

79-
chain_num = 4
80-
prev = maxtext_sweep_gke_test[0].run_with_name_gen_and_quarantine(
81-
quarantine_task_group
82-
)
83-
for i in range(1, len(maxtext_sweep_gke_test)):
84-
curr = maxtext_sweep_gke_test[i].run_with_name_gen_and_quarantine(
85-
quarantine_task_group
86-
)
87-
if i % chain_num != 0:
88-
prev >> curr
89-
prev = curr
84+
chain_num = 16
85+
prev = all_tests[0].run_with_name_gen_and_quarantine(quarantine_task_group)
86+
for i in range(1, len(all_tests)):
87+
curr = all_tests[i].run_with_name_gen_and_quarantine(quarantine_task_group)
88+
if i % chain_num != 0:
89+
_ = prev >> curr
90+
prev = curr

0 commit comments

Comments
 (0)