Skip to content

Commit e5e5fba

Browse files
authored
feat: Add a scheduling helper for tpu_observability DAGs - part 1 (#1181)
This change introduces the `SchedulingHelper` utility and the `get_dag_timeout` function to manage and calculate non-overlapping execution schedules for Airflow DAGs. This implementation ensures resource safety and configuration consistency across TPU clusters through automated time-slot allocation.
1 parent dda2309 commit e5e5fba

12 files changed

+194
-23
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helper module for scheduling DAGs across clusters."""
16+
import datetime as dt
17+
import enum
18+
from typing import TypeAlias
19+
20+
from xlml.apis.xpk_cluster_config import XpkClusterConfig
21+
from dags.common.vm_resource import TpuVersion, Zone
22+
23+
24+
class DayOfWeek(enum.Enum):
  """Cron day-of-week field values used when composing schedule strings."""

  # Every day of the week.
  ALL = "*"
  # Monday through Friday (cron days 1-5).
  WEEK_DAY = "1-5"
  # Sunday and Saturday (cron days 0 and 6).
  WEEKEND = "0,6"
28+
29+
30+
# Mock cluster used purely as a grouping key for TPU Observability DAGs;
# this module does not provision anything with it.
TPU_OBS_MOCK_CLUSTER = XpkClusterConfig(
    name="tpu-observability-automation-prod",
    project="cienet-cmcs",
    zone=Zone.US_CENTRAL1_B.value,
    device_version=TpuVersion.TRILLIUM,
    core_count=16,
)

# Maps a DAG id to the maximum wall-clock duration a single run may take.
DagIdToTimeout: TypeAlias = dict[str, dt.timedelta]

# Registry of every scheduled DAG, grouped by the cluster it runs on.
# Insertion order matters: SchedulingHelper stacks time slots in the
# order the DAG ids appear here.
REGISTERED_DAGS: dict[str, DagIdToTimeout] = {
    TPU_OBS_MOCK_CLUSTER.name: {
        "gke_node_pool_label_update": dt.timedelta(minutes=30),
        "gke_node_pool_status": dt.timedelta(minutes=30),
        "jobset_rollback_ttr": dt.timedelta(minutes=90),
        "jobset_ttr_node_pool_resize": dt.timedelta(minutes=90),
        "jobset_ttr_pod_delete": dt.timedelta(minutes=90),
        "multi-host-availability-rollback": dt.timedelta(minutes=30),
        "node_pool_ttr_disk_size": dt.timedelta(minutes=90),
        "node_pool_ttr_update_label": dt.timedelta(minutes=90),
        "tpu_info_format_validation_dag": dt.timedelta(minutes=30),
        "tpu_sdk_monitoring_validation": dt.timedelta(minutes=30),
        "jobset_ttr_kill_process": dt.timedelta(minutes=90),
    },
}
56+
57+
58+
def get_dag_timeout(dag_id: str) -> dt.timedelta:
  """Searches the registry and returns the specific timeout for a DAG.

  Args:
    dag_id: Identifier of the DAG to look up.

  Returns:
    The maximum run duration registered for the DAG.

  Raises:
    ValueError: If the DAG id is absent from REGISTERED_DAGS.
  """
  timeout = next(
      (dags[dag_id] for dags in REGISTERED_DAGS.values() if dag_id in dags),
      None,
  )
  if timeout is None:
    raise ValueError(
        f"DAG '{dag_id}' is not registered. Please add it to REGISTERED_DAGS."
    )
  return timeout
66+
67+
68+
class SchedulingHelper:
  """Manages DAG scheduling across different clusters.

  Each cluster's DAGs are laid out back to back on a daily timeline: a
  DAG's slot starts where the previous DAG's timeout (plus a safety
  margin) ends, so runs within one cluster do not overlap.
  """

  # Safety gap inserted between consecutive DAG time slots.
  DEFAULT_MARGIN = dt.timedelta(minutes=15)
  # Start of the daily schedule window (08:00 UTC); the date part is
  # arbitrary — only the time of day feeds into the cron expression.
  DEFAULT_ANCHOR = dt.datetime(2000, 1, 1, 8, 0, 0, tzinfo=dt.timezone.utc)

  @classmethod
  def arrange_schedule_time(
      cls,
      dag_id: str,
      day_of_week: DayOfWeek = DayOfWeek.ALL,
  ) -> str:
    """Calculates a cron schedule by stacking timeouts and margins.

    Args:
      dag_id: Identifier of the DAG to schedule.
      day_of_week: Cron day-of-week restriction for the schedule.

    Returns:
      A five-field cron expression, e.g. "30 9 * * *".

    Raises:
      ValueError: If the stacked slots exceed the 24-hour window before
        the DAG's slot is reached, or if the DAG id is absent from
        REGISTERED_DAGS.
    """
    anchor = cls.DEFAULT_ANCHOR

    for cluster_name, dags in REGISTERED_DAGS.items():
      if dag_id not in dags:
        continue

      offset = dt.timedelta(0)
      for current_dag_id, timeout in dags.items():
        if current_dag_id == dag_id:
          schedule = anchor + offset
          return f"{schedule.minute} {schedule.hour} * * {day_of_week.value}"
        offset += timeout + cls.DEFAULT_MARGIN
        # The guard must live inside the loop: the early return above
        # fires for every registered dag_id, so a check placed after the
        # loop would be unreachable and wrapped (overlapping) schedules
        # would be emitted silently.
        if offset >= dt.timedelta(hours=24):
          raise ValueError(
              f"Schedule exceeds 24h window at '{dag_id}' in cluster '{cluster_name}'."
          )

    raise ValueError(
        f"DAG '{dag_id}' is not registered. Please add it to REGISTERED_DAGS."
    )

dags/tpu_observability/jobset_ttr_kill_process.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
GCS_CONFIG_PATH,
4040
GCS_JOBSET_CONFIG_PATH,
4141
)
42+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
43+
44+
45+
DAG_ID = "jobset_ttr_kill_process"
46+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
47+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
4248

4349

4450
@task
@@ -70,9 +76,10 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
7076
# Keyword arguments are generated dynamically at runtime (pylint does not
7177
# know this signature).
7278
with models.DAG( # pylint: disable=unexpected-keyword-arg
73-
dag_id="jobset_ttr_kill_process",
79+
dag_id=DAG_ID,
7480
start_date=datetime.datetime(2025, 8, 10),
75-
schedule="0 15 * * *" if composer_env.is_prod_env() else None,
81+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
82+
dagrun_timeout=DAGRUN_TIMEOUT,
7683
catchup=False,
7784
tags=[
7885
"cloud-ml-auto-solutions",

dags/tpu_observability/jobset_ttr_node_pool_resize.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,21 @@
3030
GCS_CONFIG_PATH,
3131
GCS_JOBSET_CONFIG_PATH,
3232
)
33+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
3334

3435

36+
DAG_ID = "jobset_ttr_node_pool_resize"
37+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
38+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
3539
_DISK_SIZE_INCREMENT = 100
3640

3741
# Keyword arguments are generated dynamically at runtime (pylint does not
3842
# know this signature).
3943
with models.DAG( # pylint: disable=unexpected-keyword-arg
40-
dag_id="jobset_ttr_node_pool_resize",
44+
dag_id=DAG_ID,
4145
start_date=datetime.datetime(2026, 1, 27),
42-
schedule="30 17 * * *" if composer_env.is_prod_env() else None,
46+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
47+
dagrun_timeout=DAGRUN_TIMEOUT,
4348
catchup=False,
4449
tags=[
4550
"cloud-ml-auto-solutions",

dags/tpu_observability/jobset_ttr_pod_delete.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,20 @@
3030
GCS_CONFIG_PATH,
3131
GCS_JOBSET_CONFIG_PATH,
3232
)
33+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
34+
35+
36+
DAG_ID = "jobset_ttr_pod_delete"
37+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
38+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
3339

3440
# Keyword arguments are generated dynamically at runtime (pylint does not
3541
# know this signature).
3642
with models.DAG( # pylint: disable=unexpected-keyword-arg
37-
dag_id="jobset_ttr_pod_delete",
43+
dag_id=DAG_ID,
3844
start_date=datetime.datetime(2026, 1, 8),
39-
schedule="0 19 * * *" if composer_env.is_prod_env() else None,
45+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
46+
dagrun_timeout=DAGRUN_TIMEOUT,
4047
catchup=False,
4148
tags=[
4249
"cloud-ml-auto-solutions",

dags/tpu_observability/jobset_ttr_rollback.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,20 @@
3030
GCS_CONFIG_PATH,
3131
GCS_JOBSET_CONFIG_PATH,
3232
)
33+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
34+
35+
36+
DAG_ID = "jobset_rollback_ttr"
37+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
38+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
3339

3440
# Keyword arguments are generated dynamically at runtime (pylint does not
3541
# know this signature).
3642
with models.DAG( # pylint: disable=unexpected-keyword-arg
37-
dag_id="jobset_rollback_ttr",
43+
dag_id=DAG_ID,
3844
start_date=datetime.datetime(2025, 8, 10),
39-
schedule="30 22 * * *" if composer_env.is_prod_env() else None,
45+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
46+
dagrun_timeout=DAGRUN_TIMEOUT,
4047
catchup=False,
4148
tags=[
4249
"cloud-ml-auto-solutions",

dags/tpu_observability/multi_host_nodepool_rollback_dag.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@
2828
from dags.common import test_owner
2929
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
3030
from dags.tpu_observability.utils import node_pool_util as node_pool
31+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
3132

3233

34+
DAG_ID = "multi-host-availability-rollback"
35+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
36+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
37+
3338
# Keyword arguments are generated dynamically at runtime (pylint does not
3439
# know this signature).
3540
with models.DAG( # pylint: disable=unexpected-keyword-arg
36-
dag_id="multi-host-availability-rollback",
41+
dag_id=DAG_ID,
3742
start_date=datetime.datetime(2025, 8, 10),
38-
schedule="30 19 * * *" if composer_env.is_prod_env() else None,
43+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
44+
dagrun_timeout=DAGRUN_TIMEOUT,
3945
catchup=False,
4046
tags=[
4147
"cloud-ml-auto-solutions",

dags/tpu_observability/node_pool_status.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,20 @@
2626
from dags.common import test_owner
2727
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
2828
from dags.tpu_observability.utils import node_pool_util as node_pool
29+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
2930

3031

32+
DAG_ID = "gke_node_pool_status"
33+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
34+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
35+
3136
# Keyword arguments are generated dynamically at runtime (pylint does not
3237
# know this signature).
3338
with models.DAG( # pylint: disable=unexpected-keyword-arg
34-
dag_id="gke_node_pool_status",
39+
dag_id=DAG_ID,
3540
start_date=datetime.datetime(2025, 8, 1),
36-
schedule="0 18 * * *" if composer_env.is_prod_env() else None,
41+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
42+
dagrun_timeout=DAGRUN_TIMEOUT,
3743
catchup=False,
3844
tags=["gke", "tpu-observability", "node-pool-status", "TPU", "v6e-16"],
3945
description=(

dags/tpu_observability/node_pool_ttr_disk_size.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,19 @@
2525
from dags import composer_env
2626
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
2727
from dags.tpu_observability.utils import node_pool_util as node_pool
28+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
2829

29-
_DISK_SIZE_INCREMENT = 50
3030

31+
DAG_ID = "node_pool_ttr_disk_size"
32+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
33+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
34+
_DISK_SIZE_INCREMENT = 50
3135

3236
with models.DAG(
33-
dag_id="node_pool_ttr_disk_size",
37+
dag_id=DAG_ID,
3438
start_date=datetime.datetime(2025, 6, 26),
35-
schedule="0 21 * * *" if composer_env.is_prod_env() else None,
39+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
40+
dagrun_timeout=DAGRUN_TIMEOUT,
3641
catchup=False,
3742
tags=[
3843
"gke",

dags/tpu_observability/node_pool_ttr_update_label.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,18 @@
2424
from dags import composer_env
2525
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
2626
from dags.tpu_observability.utils import node_pool_util as node_pool
27+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
2728

2829

30+
DAG_ID = "node_pool_ttr_update_label"
31+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
32+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
33+
2934
with models.DAG(
30-
dag_id="node_pool_ttr_update_label",
35+
dag_id=DAG_ID,
3136
start_date=datetime.datetime(2025, 9, 30),
32-
schedule="30 21 * * *" if composer_env.is_prod_env() else None,
37+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
38+
dagrun_timeout=DAGRUN_TIMEOUT,
3339
catchup=False,
3440
tags=[
3541
"gke",

dags/tpu_observability/tpu_info_format_validation_dags.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@
4444
from dags.tpu_observability.utils import subprocess_util as subprocess
4545
from dags.tpu_observability.utils import tpu_info_util as tpu_info
4646
from dags.tpu_observability.utils.jobset_util import Workload
47+
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
48+
49+
50+
DAG_ID = "tpu_info_format_validation_dag"
51+
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
52+
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
4753

4854

4955
@task
@@ -289,10 +295,11 @@ def validate_latency_table(tpu_info_output: list[tpu_info.Table]):
289295
# Keyword arguments are generated dynamically at runtime (pylint does not
290296
# know this signature).
291297
with models.DAG( # pylint: disable=unexpected-keyword-arg
292-
dag_id="tpu_info_format_validation_dag",
298+
dag_id=DAG_ID,
293299
start_date=datetime.datetime(2025, 8, 15),
294300
default_args={"retries": 0},
295-
schedule="0 20 * * *" if composer_env.is_prod_env() else None,
301+
schedule=SCHEDULE if composer_env.is_prod_env() else None,
302+
dagrun_timeout=DAGRUN_TIMEOUT,
296303
catchup=False,
297304
tags=["gke", "tpu-observability", "tpu-info", "TPU", "v6e-16"],
298305
description=(

0 commit comments

Comments
 (0)