Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions dags/common/scheduling_helper/scheduling_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper module for scheduling DAGs across clusters."""
import datetime as dt
import enum
from typing import TypeAlias

from xlml.apis.xpk_cluster_config import XpkClusterConfig
from dags.common.vm_resource import TpuVersion, Zone


class DayOfWeek(enum.Enum):
  """Values for the day-of-week field of a cron expression (0 = Sunday)."""

  # Fire every day of the week.
  ALL = "*"
  # Fire Monday through Friday only.
  WEEK_DAY = "1-5"
  # Fire Sunday and Saturday only.
  WEEKEND = "0,6"


# Mock cluster to group TPU Observability DAGs.
# NOTE(review): only the `name` field is consumed within this module (as the
# REGISTERED_DAGS key); presumably no real cluster is provisioned from this
# config — confirm before reusing it elsewhere.
TPU_OBS_MOCK_CLUSTER = XpkClusterConfig(
    name="tpu-observability-automation-prod",
    device_version=TpuVersion.TRILLIUM,
    core_count=16,
    project="cienet-cmcs",
    zone=Zone.US_CENTRAL1_B.value,
)

# Maps a DAG id to the maximum run time allotted to that DAG.
DagIdToTimeout: TypeAlias = dict[str, dt.timedelta]

# Registry of scheduled DAGs, grouped by cluster name.
#
# IMPORTANT: dict insertion order is significant. SchedulingHelper stacks
# each DAG's timeout (plus a fixed margin) in this exact order to derive
# its daily start time, so reordering entries shifts every downstream
# DAG's cron schedule.
REGISTERED_DAGS: dict[str, DagIdToTimeout] = {
    TPU_OBS_MOCK_CLUSTER.name: {
        "gke_node_pool_label_update": dt.timedelta(minutes=30),
        "gke_node_pool_status": dt.timedelta(minutes=30),
        "jobset_rollback_ttr": dt.timedelta(minutes=90),
        "jobset_ttr_node_pool_resize": dt.timedelta(minutes=90),
        "jobset_ttr_pod_delete": dt.timedelta(minutes=90),
        "multi-host-availability-rollback": dt.timedelta(minutes=30),
        "node_pool_ttr_disk_size": dt.timedelta(minutes=90),
        "node_pool_ttr_update_label": dt.timedelta(minutes=90),
        "tpu_info_format_validation_dag": dt.timedelta(minutes=30),
        "tpu_sdk_monitoring_validation": dt.timedelta(minutes=30),
        "jobset_ttr_kill_process": dt.timedelta(minutes=90),
    },
}


def get_dag_timeout(dag_id: str) -> dt.timedelta:
  """Looks up the registered timeout for a DAG across all clusters.

  Args:
    dag_id: Identifier of the DAG to look up.

  Returns:
    The dagrun timeout registered for the DAG.

  Raises:
    ValueError: If the DAG id appears in no cluster's registry.
  """
  for dag_timeouts in REGISTERED_DAGS.values():
    timeout = dag_timeouts.get(dag_id)
    if timeout is not None:
      return timeout
  raise ValueError(
      f"DAG '{dag_id}' is not registered. Please add it to REGISTERED_DAGS."
  )


class SchedulingHelper:
  """Manages DAG scheduling across different clusters.

  DAGs registered under the same cluster in REGISTERED_DAGS run
  back-to-back each day: a DAG's start time is the anchor plus the sum of
  all preceding DAGs' timeouts and a safety margin per slot.
  """

  # Safety gap inserted between consecutive DAG slots.
  DEFAULT_MARGIN = dt.timedelta(minutes=15)
  # Daily start time of the first slot (08:00 UTC). Only the hour/minute
  # feed the cron expression; the date part is an arbitrary fixed value.
  DEFAULT_ANCHOR = dt.datetime(2000, 1, 1, 8, 0, 0, tzinfo=dt.timezone.utc)

  @classmethod
  def arrange_schedule_time(
      cls,
      dag_id: str,
      day_of_week: DayOfWeek = DayOfWeek.ALL,
  ) -> str:
    """Calculates a cron schedule by stacking timeouts and margins.

    Args:
      dag_id: Identifier of the DAG to schedule.
      day_of_week: Which days the cron expression should fire on.

    Returns:
      A five-field cron expression, e.g. "30 17 * * *".

    Raises:
      ValueError: If the DAG is not registered, or if the stacked offset
        reaches the 24h daily window (the resulting time would silently
        wrap around and overlap earlier slots).
    """
    anchor = cls.DEFAULT_ANCHOR

    for cluster_name, dags in REGISTERED_DAGS.items():
      if dag_id not in dags:
        continue

      offset = dt.timedelta(0)
      for current_dag_id, timeout in dags.items():
        if current_dag_id == dag_id:
          # Bug fix: validate the window BEFORE returning. The original
          # placed this check after the inner loop, where it was
          # unreachable (this branch always returns first once the id is
          # known to be in `dags`), so an overflowing offset would wrap
          # `schedule.hour` into an earlier wall-clock time.
          if offset >= dt.timedelta(hours=24):
            raise ValueError(
                f"Schedule exceeds 24h window at '{dag_id}' in cluster"
                f" '{cluster_name}'."
            )
          schedule = anchor + offset
          return f"{schedule.minute} {schedule.hour} * * {day_of_week.value}"
        offset += timeout + cls.DEFAULT_MARGIN

    raise ValueError(
        f"DAG '{dag_id}' is not registered. Please add it to REGISTERED_DAGS."
    )
11 changes: 9 additions & 2 deletions dags/tpu_observability/jobset_ttr_kill_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@
GCS_CONFIG_PATH,
GCS_JOBSET_CONFIG_PATH,
)
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "jobset_ttr_kill_process"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)


@task
Expand Down Expand Up @@ -70,9 +76,10 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="jobset_ttr_kill_process",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 8, 10),
schedule="0 15 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"cloud-ml-auto-solutions",
Expand Down
9 changes: 7 additions & 2 deletions dags/tpu_observability/jobset_ttr_node_pool_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,21 @@
GCS_CONFIG_PATH,
GCS_JOBSET_CONFIG_PATH,
)
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "jobset_ttr_node_pool_resize"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
_DISK_SIZE_INCREMENT = 100

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="jobset_ttr_node_pool_resize",
dag_id=DAG_ID,
start_date=datetime.datetime(2026, 1, 27),
schedule="30 17 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"cloud-ml-auto-solutions",
Expand Down
11 changes: 9 additions & 2 deletions dags/tpu_observability/jobset_ttr_pod_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,20 @@
GCS_CONFIG_PATH,
GCS_JOBSET_CONFIG_PATH,
)
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "jobset_ttr_pod_delete"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="jobset_ttr_pod_delete",
dag_id=DAG_ID,
start_date=datetime.datetime(2026, 1, 8),
schedule="0 19 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"cloud-ml-auto-solutions",
Expand Down
11 changes: 9 additions & 2 deletions dags/tpu_observability/jobset_ttr_rollback.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,20 @@
GCS_CONFIG_PATH,
GCS_JOBSET_CONFIG_PATH,
)
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "jobset_rollback_ttr"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="jobset_rollback_ttr",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 8, 10),
schedule="30 22 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"cloud-ml-auto-solutions",
Expand Down
10 changes: 8 additions & 2 deletions dags/tpu_observability/multi_host_nodepool_rollback_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@
from dags.common import test_owner
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "multi-host-availability-rollback"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="multi-host-availability-rollback",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 8, 10),
schedule="30 19 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"cloud-ml-auto-solutions",
Expand Down
10 changes: 8 additions & 2 deletions dags/tpu_observability/node_pool_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@
from dags.common import test_owner
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "gke_node_pool_status"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="gke_node_pool_status",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 8, 1),
schedule="0 18 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=["gke", "tpu-observability", "node-pool-status", "TPU", "v6e-16"],
description=(
Expand Down
11 changes: 8 additions & 3 deletions dags/tpu_observability/node_pool_ttr_disk_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@
from dags import composer_env
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout

_DISK_SIZE_INCREMENT = 50

DAG_ID = "node_pool_ttr_disk_size"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)
_DISK_SIZE_INCREMENT = 50

with models.DAG(
dag_id="node_pool_ttr_disk_size",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 6, 26),
schedule="0 21 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"gke",
Expand Down
10 changes: 8 additions & 2 deletions dags/tpu_observability/node_pool_ttr_update_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@
from dags import composer_env
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "node_pool_ttr_update_label"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

with models.DAG(
dag_id="node_pool_ttr_update_label",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 9, 30),
schedule="30 21 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"gke",
Expand Down
11 changes: 9 additions & 2 deletions dags/tpu_observability/tpu_info_format_validation_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
from dags.tpu_observability.utils import subprocess_util as subprocess
from dags.tpu_observability.utils import tpu_info_util as tpu_info
from dags.tpu_observability.utils.jobset_util import Workload
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "tpu_info_format_validation_dag"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)


@task
Expand Down Expand Up @@ -289,10 +295,11 @@ def validate_latency_table(tpu_info_output: list[tpu_info.Table]):
# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="tpu_info_format_validation_dag",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 8, 15),
default_args={"retries": 0},
schedule="0 20 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=["gke", "tpu-observability", "tpu-info", "TPU", "v6e-16"],
description=(
Expand Down
11 changes: 9 additions & 2 deletions dags/tpu_observability/tpu_sdk_monitoring_validation_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
GCS_CONFIG_PATH,
GCS_JOBSET_CONFIG_PATH,
)
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "tpu_sdk_monitoring_validation"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)


@task
Expand Down Expand Up @@ -79,9 +85,10 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:


with models.DAG(
dag_id="tpu_sdk_monitoring_validation",
dag_id=DAG_ID,
start_date=datetime.datetime(2026, 1, 13),
schedule="0 22 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=[
"cloud-ml-auto-solutions",
Expand Down
11 changes: 9 additions & 2 deletions dags/tpu_observability/update_node_pool_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,20 @@
from dags.common import test_owner
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout


DAG_ID = "gke_node_pool_label_update"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG( # pylint: disable=unexpected-keyword-arg
dag_id="gke_node_pool_label_update",
dag_id=DAG_ID,
start_date=datetime.datetime(2025, 8, 1),
schedule="30 20 * * *" if composer_env.is_prod_env() else None,
schedule=SCHEDULE if composer_env.is_prod_env() else None,
dagrun_timeout=DAGRUN_TIMEOUT,
catchup=False,
tags=["gke", "tpu-observability", "node-pool-status", "TPU", "v6e-16"],
description=(
Expand Down
Loading