forked from GoogleCloudPlatform/ml-auto-solutions
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathnode_pool_ttr_update_label.py
More file actions
127 lines (109 loc) · 4.4 KB
/
node_pool_ttr_update_label.py
File metadata and controls
127 lines (109 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A DAG to validate GKE node pool Times To Recover(TTR) metrics by triggering a label update."""
import datetime
from airflow import models
from airflow.models.baseoperator import chain
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from dags import composer_env
from dags.tpu_observability.configs.common import MachineConfigMap, GCS_CONFIG_PATH
from dags.tpu_observability.utils import node_pool_util as node_pool
# Single source of truth for the DAG id: it is used both as the Airflow
# `dag_id` and as the `dag_name` handed to the node pool info builder, so the
# two can no longer drift apart.
DAG_ID = "node_pool_ttr_update_label"

# Label delta applied to each node pool to trigger the update under test.
# Defined once here instead of being rebuilt on every loop iteration.
LABELS_TO_UPDATE = {"test_key": "test_val"}

# Markdown shown as the DAG documentation. The text is kept at column 0 so
# that the Markdown headings and list items carry no leading indentation.
DAG_DOC_MD = """
# GKE Node Pool Times To Recover(TTR) Metric Validation DAG
### Description
This DAG automates the validation of GKE node pool Times To Recover(TTR) metrics.
It creates a node pool and updates its labels then verifies that the TTR metric
is correctly generated and reported to Google Cloud Monitoring.
### Prerequisites
This test requires an existing GKE cluster.
### Procedures
1. Create a temporary node pool.
2. Wait for the node pool to be RUNNING.
3. Update the node pool label.
4. Wait for the Times To Recover(TTR) metrics to appear in Google Cloud Monitoring.
5. Clean up the node pool after the tests.
"""

with models.DAG(
    dag_id=DAG_ID,
    start_date=datetime.datetime(2025, 9, 30),
    # A cron schedule is only set in the prod environment; elsewhere the
    # schedule is None.
    schedule="30 21 * * *" if composer_env.is_prod_env() else None,
    catchup=False,
    tags=[
        "gke",
        "tpu-observability",
        "node-pool-ttr-update-label",
        "TPU",
        "v6e-16",
    ],
    description=(
        "This DAG verifies the GKE node pool's Times To Recover(TTR) metrics "
        "by triggering a label update and confirming the recovery time "
        "is recorded"
    ),
    doc_md=DAG_DOC_MD,
) as dag:
  # One task group, holding the full pipeline, per entry in MachineConfigMap.
  for machine in MachineConfigMap:
    config = machine.value

    with TaskGroup(group_id=f"v{config.tpu_version.value}"):
      # Resolve the node pool description from the shared GCS config file.
      node_pool_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
          task_id="build_node_pool_info_from_gcs_yaml"
      )(
          gcs_path=GCS_CONFIG_PATH,
          dag_name=DAG_ID,
          is_prod=composer_env.is_prod_env(),
          machine_type=config.machine_version.value,
          tpu_topology=config.tpu_topology,
      )

      create_node_pool = node_pool.create.override(
          task_id="create_node_pool"
      )(node_pool=node_pool_info)

      wait_for_provisioning = node_pool.wait_for_status.override(
          task_id="wait_for_provisioning"
      )(node_pool=node_pool_info, status=node_pool.Status.PROVISIONING)

      wait_for_running = node_pool.wait_for_status.override(
          task_id="wait_for_running"
      )(node_pool=node_pool_info, status=node_pool.Status.RUNNING)

      # The label update is the operation whose recovery time is measured.
      # A copy of the constant is passed so every task group keeps its own
      # dict, exactly as when the dict was rebuilt inside the loop.
      update_node_pool_label = node_pool.update.override(
          task_id="update_node_pool_label"
      )(
          node_pool=node_pool_info,
          spec=node_pool.NodePoolUpdateSpec.Label(
              delta=dict(LABELS_TO_UPDATE)
          ),
      )

      # After the update, wait for the node pool to report RUNNING again.
      wait_for_recovered = node_pool.wait_for_status.override(
          task_id="wait_for_recovered"
      )(node_pool=node_pool_info, status=node_pool.Status.RUNNING)

      # The output of the update task is passed as `operation_start_time`.
      # NOTE(review): presumably the timestamp at which the update started;
      # verify against node_pool_util.update.
      wait_for_ttr = node_pool.wait_for_ttr.override(task_id="wait_for_ttr")(
          node_pool=node_pool_info,
          operation_start_time=update_node_pool_label,
      )

      # Deletion is registered as the teardown of `create_node_pool`.
      # NOTE(review): `as_teardown` may itself adjust the trigger rule;
      # confirm that ALL_DONE is the rule that actually takes effect.
      cleanup_node_pool = node_pool.delete.override(
          task_id="cleanup_node_pool", trigger_rule=TriggerRule.ALL_DONE
      )(node_pool=node_pool_info).as_teardown(
          setups=create_node_pool,
      )

      # Run the steps strictly one after another.
      chain(
          node_pool_info,
          create_node_pool,
          wait_for_provisioning,
          wait_for_running,
          update_node_pool_label,
          wait_for_recovered,
          wait_for_ttr,
          cleanup_node_pool,
      )