diff --git a/dags/tpu_observability/jobset_ttr_node_pool_resize.py b/dags/tpu_observability/jobset_ttr_node_pool_resize.py
new file mode 100644
index 000000000..b8d927e82
--- /dev/null
+++ b/dags/tpu_observability/jobset_ttr_node_pool_resize.py
@@ -0,0 +1,166 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A DAG to test JobSet time-to-recover metric using a node pool disk resize."""
+
+import datetime
+
+from airflow import models
+from airflow.models.baseoperator import chain
+from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.task_group import TaskGroup
+
+from dags import composer_env
+from dags.tpu_observability.utils import jobset_util as jobset
+from dags.tpu_observability.utils import node_pool_util as node_pool
+from dags.tpu_observability.utils.jobset_util import JobSet, Workload
+from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
+
+_DISK_SIZE_INCREMENT = 100
+
+# Keyword arguments are generated dynamically at runtime (pylint does not
+# know this signature).
+with models.DAG(  # pylint: disable=unexpected-keyword-arg
+    dag_id="jobset_ttr_node_pool_resize",
+    start_date=datetime.datetime(2026, 1, 27),
+    schedule="30 17 * * *" if composer_env.is_prod_env() else None,
+    catchup=False,
+    tags=[
+        "cloud-ml-auto-solutions",
+        "jobset",
+        "time-to-recover",
+        "tpu-observability",
+        "node-pool-resize",
+        "TPU",
+        "v6e-16",
+    ],
+    description=(
+        "This DAG tests the JobSet time-to-recover metric by triggering a "
+        "node pool disk resize, then polls the metric to check if it is updated."
+    ),
+    doc_md="""
+    # JobSet Time-To-Recover (TTR) Test Using Node Pool Disk Resize
+
+    ### Description
+    This DAG verifies that JobSet can recover when the underlying node pool
+    undergoes a disruptive update (Disk Resize). It launches a JobSet,
+    increases the disk size of the node pool, and confirms that the
+    JobSet controller restarts the workload successfully.
+
+    ### Prerequisites
+    This test requires an existing cluster to run.
+
+    ### Procedures
+    First the node-pool is created, a jobset yaml is then launched on the
+    cluster and given a short period of time to initialize. After this a
+    node pool disk resize is triggered to interrupt the jobset. A sensor is
+    finally run which will poll Cloud Monitoring to detect that the jobset
+    time-to-recover (TTR) metric has been updated, resulting in success,
+    or timing out and failing.
+    """,
+) as dag:
+  for machine in MachineConfigMap:
+    config = machine.value
+
+    jobset_config = JobSet(
+        jobset_name="ttr-res-v6e",
+        namespace="default",
+        max_restarts=10,
+        replicated_job_name="tpu-job-slice",
+        replicas=1,
+        backoff_limit=0,
+        completions=4,
+        parallelism=4,
+        tpu_accelerator_type="tpu-v6e-slice",
+        tpu_topology="4x4",
+        container_name="jax-tpu-worker",
+        image="python:3.11",
+        tpu_cores_per_pod=4,
+    )
+
+    # Keyword arguments are generated dynamically at runtime (pylint does not
+    # know this signature).
+    with TaskGroup(  # pylint: disable=unexpected-keyword-arg
+        group_id=f"v{config.tpu_version.value}"
+    ):
+      cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
+          task_id="build_node_pool_info_from_gcs_yaml"
+      )(
+          gcs_path=GCS_CONFIG_PATH,
+          dag_name="jobset_ttr_node_pool_resize",
+          is_prod=composer_env.is_prod_env(),
+          machine_type=config.machine_version.value,
+          tpu_topology=config.tpu_topology,
+      )
+
+      create_node_pool = node_pool.create.override(task_id="create_node_pool")(
+          node_pool=cluster_info,
+      )
+
+      start_workload = jobset.run_workload.override(task_id="start_workload")(
+          node_pool=cluster_info,
+          yaml_config=jobset_config.generate_yaml(
+              workload_script=Workload.JAX_TPU_BENCHMARK
+          ),
+          namespace=jobset_config.namespace,
+      )
+
+      ensure_all_pods_running = jobset.wait_for_all_pods_running.override(
+          task_id="ensure_all_pods_running"
+      )(
+          num_pods=(jobset_config.replicas * jobset_config.parallelism),
+          node_pool=cluster_info,
+      )
+
+      node_pool_resize = node_pool.update.override(task_id="node_pool_resize")(
+          node_pool=cluster_info,
+          spec=node_pool.NodePoolUpdateSpec.DiskSize(
+              delta=_DISK_SIZE_INCREMENT
+          ),
+      )
+
+      wait_for_metric_upload = jobset.wait_for_jobset_ttr_to_be_found.override(
+          task_id="wait_for_jobset_ttr_to_be_found"
+      )(
+          node_pool=cluster_info,
+          jobset_name=jobset_config.jobset_name,
+          start_time=node_pool_resize,
+      )
+
+      cleanup_workload = jobset.end_workload.override(
+          task_id="cleanup_workload", trigger_rule=TriggerRule.ALL_DONE
+      )(
+          node_pool=cluster_info,
+          jobset_name=jobset_config.jobset_name,
+          namespace=jobset_config.namespace,
+      ).as_teardown(
+          setups=start_workload
+      )
+
+      cleanup_node_pool = node_pool.delete.override(
+          task_id="cleanup_node_pool", trigger_rule=TriggerRule.ALL_DONE
+      )(node_pool=cluster_info).as_teardown(
+          setups=create_node_pool,
+      )
+
+      chain(
+          cluster_info,
+          create_node_pool,
+          start_workload,
+          ensure_all_pods_running,
+          node_pool_resize,
+          wait_for_metric_upload,
+          cleanup_workload,
+          cleanup_node_pool,
+      )
diff --git a/dags/tpu_observability/utils/jobset_util.py b/dags/tpu_observability/utils/jobset_util.py
index b76d2dd48..5a287f8f9 100644
--- a/dags/tpu_observability/utils/jobset_util.py
+++ b/dags/tpu_observability/utils/jobset_util.py
@@ -594,9 +594,9 @@ def wait_for_jobset_started(
 
 @task.sensor(poke_interval=60, timeout=3600, mode="poke")
 def wait_for_jobset_ttr_to_be_found(
-    node_pool: node_pool_info, jobset_name: str
+    node_pool: node_pool_info, jobset_name: str, start_time: TimeUtil = None
 ) -> bool:
-  """Polls the jobset time_between_interruptions metric.
+  """Polls the jobset time-to-recover metric.
 
   A sensor task which polls the jobset time_between_interruptions metric
   every 60 seconds for 60 minutes. 60 minutes is used here since this
@@ -606,11 +606,20 @@
   impractical for the test to run longer.
 
   Args:
-    node_pool: An instance of the Info class that encapsulates
-      the configuration and metadata of a GKE node pool and workload.
-    jobset_name: The name of the JobSet.
+    node_pool (Info): An instance of the Info class containing GKE metadata.
+    jobset_name (str): The name of the JobSet to monitor.
+    start_time (TimeUtil, optional): The UTC timestamp to start polling from.
+      If not provided, defaults to 60 minutes before the current time.
+
+  Returns:
+    bool: True if the TTR metric is found in Cloud Monitoring, False otherwise.
   """
   now = datetime.datetime.now()
+  query_start = (
+      start_time
+      if start_time
+      else TimeUtil.from_datetime(now - datetime.timedelta(minutes=60))
+  )
 
   time_series = query_time_series(
       project_id=node_pool.project_id,
@@ -619,7 +628,7 @@
           f'resource.labels.cluster_name="{node_pool.cluster_name}" '
           f'resource.labels.entity_name="{jobset_name}"'
       ),
-      start_time=TimeUtil.from_datetime(now - datetime.timedelta(minutes=60)),
+      start_time=query_start,
       end_time=TimeUtil.from_datetime(now),
   )
 