Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions dags/tpu_observability/jobset_ttr_node_pool_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to test JobSet time-to-recover metric using a node pool disk resize."""

import datetime

from airflow import models
from airflow.models.baseoperator import chain
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.task_group import TaskGroup

from dags import composer_env
from dags.tpu_observability.utils import jobset_util as jobset
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.tpu_observability.utils.jobset_util import JobSet, Workload
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH

# How much to grow the node pool's boot disk by when triggering the
# disruptive update. NOTE(review): units presumed to be GB — confirm against
# node_pool_util.NodePoolUpdateSpec.DiskSize.
_DISK_SIZE_INCREMENT = 100

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG(  # pylint: disable=unexpected-keyword-arg
    dag_id="jobset_ttr_node_pool_resize",
    start_date=datetime.datetime(2026, 1, 27),
    schedule="30 17 * * *" if composer_env.is_prod_env() else None,
    catchup=False,
    tags=[
        "cloud-ml-auto-solutions",
        "jobset",
        "time-to-recover",
        "tpu-observability",
        "node-pool-resize",
        "TPU",
        "v6e-16",
    ],
    description=(
        "This DAG tests the JobSet time-to-recover metric by triggering a "
        "node pool disk resize, then polls the metric to check if it is updated."
    ),
    doc_md="""
    # JobSet Time-To-Recover (TTR) Test Using Node Pool Disk Resize

    ### Description
    This DAG verifies that JobSet can recover when the underlying node pool
    undergoes a disruptive update (Disk Resize). It launches a JobSet,
    increases the disk size of the node pool, and confirms that the
    JobSet controller restarts the workload successfully.

    ### Prerequisites
    This test requires an existing cluster to run.

    ### Procedures
    First the node-pool is created, a jobset yaml is then launched on the
    cluster and given a short period of time to initialize. After this a
    node pool disk resize is triggered to interrupt the jobset. A sensor is
    finally run which will poll Cloud Monitoring to detect that the jobset
    time-to-recover (TTR) metric has been updated, resulting in a success,
    or timeout, and fail.
    """,
) as dag:
  # One task group is generated per machine configuration in the map.
  for machine_entry in MachineConfigMap:
    machine_config = machine_entry.value

    # Declarative spec of the JobSet workload launched on the pool under test.
    jobset_spec = JobSet(
        jobset_name="ttr-res-v6e",
        namespace="default",
        max_restarts=10,
        replicated_job_name="tpu-job-slice",
        replicas=1,
        backoff_limit=0,
        completions=4,
        parallelism=4,
        tpu_accelerator_type="tpu-v6e-slice",
        tpu_topology="4x4",
        container_name="jax-tpu-worker",
        image="python:3.11",
        tpu_cores_per_pod=4,
    )

    # Keyword arguments are generated dynamically at runtime (pylint does not
    # know this signature).
    with TaskGroup(  # pylint: disable=unexpected-keyword-arg
        group_id=f"v{machine_config.tpu_version.value}"
    ):
      # Resolve this machine type's node pool configuration from the shared
      # YAML config stored in GCS.
      pool_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
          task_id="build_node_pool_info_from_gcs_yaml"
      )(
          gcs_path=GCS_CONFIG_PATH,
          dag_name="jobset_ttr_node_pool_resize",
          is_prod=composer_env.is_prod_env(),
          machine_type=machine_config.machine_version.value,
          tpu_topology=machine_config.tpu_topology,
      )

      # Provision the node pool the JobSet will run on.
      provision_pool = node_pool.create.override(task_id="create_node_pool")(
          node_pool=pool_info,
      )

      # Launch the JobSet from the generated YAML.
      launch_jobset = jobset.run_workload.override(task_id="start_workload")(
          node_pool=pool_info,
          yaml_config=jobset_spec.generate_yaml(
              workload_script=Workload.JAX_TPU_BENCHMARK
          ),
          namespace=jobset_spec.namespace,
      )

      # Block until every expected pod (replicas * parallelism) is running,
      # so the resize interrupts an initialized workload.
      pods_ready = jobset.wait_for_all_pods_running.override(
          task_id="ensure_all_pods_running"
      )(
          num_pods=jobset_spec.replicas * jobset_spec.parallelism,
          node_pool=pool_info,
      )

      # Disruptive update: grow the boot disk to force node recreation and
      # interrupt the running JobSet.
      resize_disk = node_pool.update.override(task_id="node_pool_resize")(
          node_pool=pool_info,
          spec=node_pool.NodePoolUpdateSpec.DiskSize(
              delta=_DISK_SIZE_INCREMENT
          ),
      )

      # Sensor: poll Cloud Monitoring for the TTR metric, starting from the
      # moment the resize task finished (its return value feeds start_time).
      poll_ttr_metric = jobset.wait_for_jobset_ttr_to_be_found.override(
          task_id="wait_for_jobset_ttr_to_be_found"
      )(
          node_pool=pool_info,
          jobset_name=jobset_spec.jobset_name,
          start_time=resize_disk,
      )

      # Teardown of the workload; runs regardless of upstream outcome and is
      # paired with the workload-launch setup task.
      teardown_jobset = jobset.end_workload.override(
          task_id="cleanup_workload", trigger_rule=TriggerRule.ALL_DONE
      )(
          node_pool=pool_info,
          jobset_name=jobset_spec.jobset_name,
          namespace=jobset_spec.namespace,
      ).as_teardown(setups=launch_jobset)

      # Teardown of the node pool; paired with the pool-creation setup task.
      teardown_pool = node_pool.delete.override(
          task_id="cleanup_node_pool", trigger_rule=TriggerRule.ALL_DONE
      )(node_pool=pool_info).as_teardown(setups=provision_pool)

      # Linear execution order; `>>` is equivalent to chain(...) here.
      (
          pool_info
          >> provision_pool
          >> launch_jobset
          >> pods_ready
          >> resize_disk
          >> poll_ttr_metric
          >> teardown_jobset
          >> teardown_pool
      )
21 changes: 15 additions & 6 deletions dags/tpu_observability/utils/jobset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ def wait_for_jobset_started(

@task.sensor(poke_interval=60, timeout=3600, mode="poke")
def wait_for_jobset_ttr_to_be_found(
node_pool: node_pool_info, jobset_name: str
node_pool: node_pool_info, jobset_name: str, start_time: TimeUtil = None
) -> bool:
"""Polls the jobset time_between_interruptions metric.
"""Polls the jobset time-to-recover metric.

A sensor task which polls the jobset time_between_interruptions metric
every 60 seconds for 60 minutes. 60 minutes is used here since this
Expand All @@ -606,11 +606,20 @@ def wait_for_jobset_ttr_to_be_found(
impractical for the test to run longer.

Args:
node_pool: An instance of the Info class that encapsulates
the configuration and metadata of a GKE node pool and workload.
jobset_name: The name of the JobSet.
node_pool (Info): An instance of the Info class containing GKE metadata.
jobset_name (str): The name of the JobSet to monitor.
start_time (TimeUtil, optional): The UTC timestamp to start polling from.
If not provided, defaults to 60 minutes before the current time.

Returns:
bool: True if the TTR metric is found in Cloud Monitoring, False otherwise.
"""
now = datetime.datetime.now()
query_start = (
start_time
if start_time
else TimeUtil.from_datetime(now - datetime.timedelta(minutes=60))
)

time_series = query_time_series(
project_id=node_pool.project_id,
Expand All @@ -619,7 +628,7 @@ def wait_for_jobset_ttr_to_be_found(
f'resource.labels.cluster_name="{node_pool.cluster_name}" '
f'resource.labels.entity_name="{jobset_name}"'
),
start_time=TimeUtil.from_datetime(now - datetime.timedelta(minutes=60)),
start_time=query_start,
end_time=TimeUtil.from_datetime(now),
)

Expand Down
Loading