forked from GoogleCloudPlatform/ml-auto-solutions
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathjobset_ttr_node_pool_resize.py
More file actions
163 lines (143 loc) · 5.56 KB
/
jobset_ttr_node_pool_resize.py
File metadata and controls
163 lines (143 loc) · 5.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A DAG to test JobSet time-to-recover metric using a node pool disk resize."""
import datetime
from airflow import models
from airflow.models.baseoperator import chain
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.task_group import TaskGroup
from dags import composer_env
from dags.tpu_observability.utils import jobset_util as jobset
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.tpu_observability.utils.jobset_util import JobSet, Workload
from dags.tpu_observability.configs.common import (
MachineConfigMap,
GCS_CONFIG_PATH,
GCS_JOBSET_CONFIG_PATH,
)
from dags.common.scheduling_helper.scheduling_helper import SchedulingHelper, get_dag_timeout
DAG_ID = "jobset_ttr_node_pool_resize"
DAGRUN_TIMEOUT = get_dag_timeout(DAG_ID)
SCHEDULE = SchedulingHelper.arrange_schedule_time(DAG_ID)

# Delta by which the node pool's disk size is increased to trigger the
# disruptive update (units, presumably GB, are defined by
# NodePoolUpdateSpec.DiskSize — confirm in node_pool_util).
_DISK_SIZE_INCREMENT = 100

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG(  # pylint: disable=unexpected-keyword-arg
    dag_id=DAG_ID,
    start_date=datetime.datetime(2026, 1, 27),
    # Only run on a schedule in prod; dev/staging environments trigger manually.
    schedule=SCHEDULE if composer_env.is_prod_env() else None,
    dagrun_timeout=DAGRUN_TIMEOUT,
    catchup=False,
    tags=[
        "cloud-ml-auto-solutions",
        "jobset",
        "time-to-recover",
        "tpu-observability",
        "node-pool-resize",
        "TPU",
        "v6e-16",
    ],
    description=(
        "This DAG tests the JobSet time-to-recover metric by triggering a "
        "node pool disk resize, then polls the metric to check "
        "if it is updated."
    ),
    doc_md="""
    # JobSet Time-To-Recover (TTR) Test Using Node Pool Disk Resize
    ### Description
    This DAG verifies that JobSet can recover when the underlying node pool
    undergoes a disruptive update (Disk Resize). It launches a JobSet,
    increases the disk size of the node pool, and confirms that the
    JobSet controller restarts the workload successfully.
    ### Prerequisites
    This test requires an existing cluster to run.
    ### Procedures
    First the node-pool is created, a jobset yaml is then launched on the
    cluster and given a short period of time to initialize. After this a
    node pool disk resize is triggered to interrupt the jobset. A sensor is
    finally run which will poll Cloud Monitoring to detect that the jobset
    time-to-recover (TTR) metric has been updated, resulting in a success,
    or timeout, and fail.
    """,
) as dag:
  # One TaskGroup per machine configuration so all variants run in parallel
  # within a single DAG run.
  for machine in MachineConfigMap:
    config = machine.value
    # Keyword arguments are generated dynamically at runtime (pylint does not
    # know this signature).
    with TaskGroup(  # pylint: disable=unexpected-keyword-arg
        group_id=f"v{config.tpu_version.value}"
    ):
      # Load the JobSet spec for this DAG from GCS. Use DAG_ID (rather than
      # repeating the literal) so the config lookup can never drift from the
      # DAG's registered name.
      jobset_config = jobset.build_jobset_from_gcs_yaml(
          gcs_path=GCS_JOBSET_CONFIG_PATH,
          dag_name=DAG_ID,
      )
      # Resolve cluster/node-pool parameters (project, zone, machine type,
      # topology) from the shared GCS config for this environment.
      cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
          task_id="build_node_pool_info_from_gcs_yaml"
      )(
          gcs_path=GCS_CONFIG_PATH,
          dag_name=DAG_ID,
          is_prod=composer_env.is_prod_env(),
          machine_type=config.machine_version.value,
          tpu_topology=config.tpu_topology,
      )
      create_node_pool = node_pool.create.override(task_id="create_node_pool")(
          node_pool=cluster_info,
      )
      start_workload = jobset.run_workload.override(task_id="start_workload")(
          node_pool=cluster_info,
          jobset_config=jobset_config,
          workload_type=Workload.JAX_TPU_BENCHMARK,
      )
      # Wait for the workload to fully initialize before disrupting it, so
      # the TTR measurement reflects recovery from a steady state.
      ensure_all_pods_running = jobset.wait_for_all_pods_running.override(
          task_id="ensure_all_pods_running"
      )(
          node_pool=cluster_info,
          jobset_config=jobset_config,
      )
      # The disruptive event under test: grow the node pool's disk, which
      # interrupts the running JobSet.
      node_pool_resize = node_pool.update.override(task_id="node_pool_resize")(
          node_pool=cluster_info,
          spec=node_pool.NodePoolUpdateSpec.DiskSize(
              delta=_DISK_SIZE_INCREMENT
          ),
      )
      # Sensor: polls Cloud Monitoring until the jobset TTR metric appears
      # (success) or the task times out (failure).
      wait_for_metric_upload = jobset.wait_for_jobset_ttr_to_be_found.override(
          task_id="wait_for_jobset_ttr_to_be_found"
      )(
          node_pool=cluster_info,
          jobset_config=jobset_config,
      )
      # Teardown tasks run regardless of upstream outcome (ALL_DONE) and are
      # paired with their setup tasks so failures in teardown don't mask the
      # test result.
      cleanup_workload = jobset.end_workload.override(
          task_id="cleanup_workload", trigger_rule=TriggerRule.ALL_DONE
      )(
          node_pool=cluster_info,
          jobset_config=jobset_config,
      ).as_teardown(
          setups=start_workload
      )
      cleanup_node_pool = node_pool.delete.override(
          task_id="cleanup_node_pool", trigger_rule=TriggerRule.ALL_DONE
      )(node_pool=cluster_info).as_teardown(
          setups=create_node_pool,
      )
      # Linear ordering: build configs -> create pool -> run workload ->
      # disrupt -> verify metric -> clean up (workload first, then pool).
      chain(
          jobset_config,
          cluster_info,
          create_node_pool,
          start_workload,
          ensure_all_pods_running,
          node_pool_resize,
          wait_for_metric_upload,
          cleanup_workload,
          cleanup_node_pool,
      )