Skip to content

Commit a8bb2af

Browse files
authored
Add a new DAG that tests the TPU monitoring SDK (#1153)
This change implements a new DAG `tpu_sdk_monitoring_validation`. It performs end-to-end functional validation of the `tpumonitoring` Python SDK inside TPU worker pods to ensure the observability stack is correctly configured and accessible on v6e slices.
1 parent 89360dd commit a8bb2af

File tree

2 files changed

+280
-0
lines changed

2 files changed

+280
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A DAG to validate the `tpumonitoring` SDK, ensuring help() and
16+
list_supported_metrics() are functional inside TPU worker pods."""
17+
18+
import datetime
19+
20+
from airflow import models
21+
from airflow.models.baseoperator import chain
22+
from airflow.utils.trigger_rule import TriggerRule
23+
from airflow.utils.task_group import TaskGroup
24+
from airflow.decorators import task
25+
26+
from dags import composer_env
27+
from dags.tpu_observability.utils import jobset_util as jobset
28+
from dags.tpu_observability.utils import tpu_monitoring_sdk_util as sdk
29+
from dags.tpu_observability.utils import node_pool_util as node_pool
30+
from dags.tpu_observability.utils.jobset_util import JobSet, Workload
31+
from dags.tpu_observability.configs.common import (
32+
MachineConfigMap,
33+
GCS_CONFIG_PATH,
34+
)
35+
36+
37+
@task
def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
  """Validates the tpumonitoring SDK functions inside TPU worker pods.

  This task executes both help() and list_supported_metrics() via the SDK
  and verifies that the output contains the expected strings and patterns.

  Args:
    info: Cluster info for gcloud credentials.
    pod_name: Pod name provided by dynamic task mapping.

  Raises:
    AssertionError: If any expected pattern is missing from the output of
      the corresponding SDK call.
  """
  # Maps the SDK function name (used for error reporting) to the script to
  # execute and the patterns its output must contain. The name is carried
  # explicitly because TpuMonitoringScript attributes are plain strings
  # (textwrap.dedent results), so the previous `script.name.lower()` raised
  # AttributeError instead of the intended AssertionError whenever a
  # pattern was missing.
  validate_spec: dict[str, tuple[str, list[str]]] = {
      # Validates help() output. Expected format:
      # - list_supported_metrics(): List all supported functionality...
      # - get_metric(metric_name:str): Get specific metric...
      # - snapshot mode: Enable real-time monitoring...
      "help": (
          sdk.TpuMonitoringScript.HELP,
          [
              "list_supported_metrics()",
              "get_metric(metric_name:str)",
              "snapshot mode",
          ],
      ),
      # Validates list_supported_metrics() output. Expected format:
      # ['tensorcore_util', 'duty_cycle_pct', 'hbm_capacity_usage', ...]
      "list_supported_metrics": (
          sdk.TpuMonitoringScript.LIST_SUPPORTED_METRICS,
          [
              "tensorcore_util",
              "duty_cycle_pct",
              "hbm_capacity_usage",
              "buffer_transfer_latency",
              "hlo_execution_timing",
          ],
      ),
  }

  for func_name, (script, patterns) in validate_spec.items():
    output = sdk.execute_sdk_command(info, pod_name, script)
    for pattern in patterns:
      if pattern not in output:
        raise AssertionError(
            f"Validation failed for 'tpumonitoring.{func_name}()': "
            f"Missing '{pattern}'."
        )
78+
79+
80+
# DAG definition: provisions a v6e node pool, runs a JobSet workload, then
# fans out the SDK validation task across every worker pod before tearing
# the workload and node pool down.
with models.DAG(
    dag_id="tpu_sdk_monitoring_validation",
    start_date=datetime.datetime(2026, 1, 13),
    # Runs daily at 18:00 in prod; manual-trigger only elsewhere.
    schedule="0 18 * * *" if composer_env.is_prod_env() else None,
    catchup=False,
    tags=[
        "cloud-ml-auto-solutions",
        "jobset",
        "tpu-observability",
        "TPU",
        "v6e-16",
        "tpu-monitoring-sdk",
    ],
    description=(
        "Validates tpumonitoring SDK: help() and "
        "list_supported_metrics() inside TPU worker pods."
    ),
    # NOTE(review): step 2 below cites `ici_link_health` as a verified
    # metric, but the pattern list in validate_monitoring_sdk does not
    # include it — confirm which is intended.
    doc_md="""
### Description
This DAG performs an end-to-end validation of the `tpumonitoring` Python SDK
within TPU worker pods. It ensures the SDK is correctly installed and its
monitoring functions are accessible via `libtpu.sdk`.

### Validation Steps:
1. **SDK Help Documentation Validation**:
Executes `tpumonitoring.help()` to verify that the API documentation is
correctly rendered and includes essential methods like `list_supported_metrics`.

2. **Metric Catalog Validation**:
Executes `tpumonitoring.list_supported_metrics()` and verifies that
core TPU metrics (e.g., `tensorcore_util`, `hbm_capacity_usage`, `ici_link_health`)
are present in the returned list.

3. **Environment Integrity Check**:
Ensures the `libtpu` library can correctly interface with the TPU driver
and hardware devices inside the container.
""",
) as dag:
  # One task group per machine configuration in the shared config map.
  for machine in MachineConfigMap:
    config = machine.value

    # NOTE(review): the JobSet name, accelerator type, and topology are
    # hard-coded for a single v6e 4x4 slice, while the node pool below is
    # built from `config` — if MachineConfigMap ever holds more than one
    # entry, the fixed jobset_name/topology would collide or mismatch.
    # Confirm the map is intentionally single-entry.
    jobset_config = JobSet(
        jobset_name="sdk-monitoring-v6e-workload",
        namespace="default",
        max_restarts=5,
        replicated_job_name="tpu-job-slice",
        replicas=1,
        backoff_limit=0,
        completions=4,
        parallelism=4,
        tpu_accelerator_type="tpu-v6e-slice",
        tpu_topology="4x4",
        container_name="jax-tpu-worker",
        image="asia-northeast1-docker.pkg.dev/cienet-cmcs/"
        "yuna-docker/tpu-info:v0.5.1",
        tpu_cores_per_pod=4,
    )

    # Keyword arguments are generated dynamically at runtime (pylint does not
    # know this signature).
    with TaskGroup(  # pylint: disable=unexpected-keyword-arg
        group_id=f"v{config.tpu_version.value}"
    ):
      # Resolves cluster/node-pool parameters from the YAML stored in GCS.
      cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
          task_id="build_node_pool_info_from_gcs_yaml"
      )(
          gcs_path=GCS_CONFIG_PATH,
          dag_name="tpu_sdk_monitoring_validation",
          is_prod=composer_env.is_prod_env(),
          machine_type=config.machine_version.value,
          tpu_topology=config.tpu_topology,
      )

      # Provisions the TPU node pool the workload will run on.
      create_node_pool = node_pool.create.override(task_id="create_node_pool")(
          node_pool=cluster_info,
      )

      # Applies the JobSet YAML; returns the apply timestamp used later to
      # detect when the jobset has actually started.
      apply_time = jobset.run_workload.override(task_id="run_workload")(
          node_pool=cluster_info,
          yaml_config=jobset_config.generate_yaml(
              workload_script=Workload.JAX_TPU_BENCHMARK
          ),
          namespace=jobset_config.namespace,
      )

      # Lists the worker pod names; feeds the dynamic task mapping below.
      pod_names = jobset.list_pod_names.override(task_id="list_pod_names")(
          node_pool=cluster_info,
          namespace=jobset_config.namespace,
      )

      # Blocks until all pods report started relative to apply_time.
      wait_for_jobset_started = jobset.wait_for_jobset_started.override(
          task_id="wait_for_jobset_started"
      )(
          node_pool=cluster_info,
          pod_name_list=pod_names,
          job_apply_time=apply_time,
      )

      # One mapped validation task instance per worker pod.
      sdk_validation = (
          validate_monitoring_sdk.override(task_id="sdk_validation")
          .partial(info=cluster_info)
          .expand(pod_name=pod_names)
      )

      # Teardown pair for run_workload: always runs (ALL_DONE), even when
      # validation fails, so the jobset never leaks.
      cleanup_workload = jobset.end_workload.override(
          task_id="cleanup_workload", trigger_rule=TriggerRule.ALL_DONE
      )(
          node_pool=cluster_info,
          jobset_name=jobset_config.jobset_name,
          namespace=jobset_config.namespace,
      ).as_teardown(
          setups=apply_time
      )

      # Teardown pair for create_node_pool: deletes the node pool
      # unconditionally once upstream tasks finish.
      cleanup_node_pool = node_pool.delete.override(
          task_id="cleanup_node_pool", trigger_rule=TriggerRule.ALL_DONE
      )(node_pool=cluster_info).as_teardown(
          setups=create_node_pool,
      )

      # Linear dependency chain; cluster_info is implicit upstream of every
      # task that consumes it via XCom.
      chain(
          create_node_pool,
          apply_time,
          pod_names,
          wait_for_jobset_started,
          sdk_validation,
          cleanup_workload,
          cleanup_node_pool,
      )
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for executing Python commands within the TPU Monitoring SDK."""
16+
17+
import os
import shlex
import tempfile
import textwrap

from dags.tpu_observability.utils import jobset_util as jobset
from dags.tpu_observability.utils import node_pool_util as node_pool
from dags.tpu_observability.utils import subprocess_util as subprocess
24+
25+
26+
class TpuMonitoringScript:
  """Predefined Python scripts for TPU monitoring SDK.

  Each attribute is a plain ``str`` holding a tiny Python program that is
  handed to ``python3 -c`` inside a TPU worker pod.
  """

  # Prints the SDK's built-in API documentation.
  HELP = (
      "\n"
      "from libtpu.sdk import tpumonitoring\n"
      "tpumonitoring.help()\n"
  )

  # Prints the list of metric names supported by the SDK on this host.
  LIST_SUPPORTED_METRICS = (
      "\n"
      "from libtpu.sdk import tpumonitoring\n"
      "print(tpumonitoring.list_supported_metrics())\n"
  )
42+
43+
44+
def execute_sdk_command(
    info: node_pool.Info,
    pod_name: str,
    script: TpuMonitoringScript,
    namespace: str = "default",
) -> str:
  """Executes a predefined Python script inside a specific TPU pod via kubectl exec.

  Args:
    info: Node pool and cluster information.
    pod_name: The name of the target pod.
    script: The Python script to run (an attribute of TpuMonitoringScript;
      a plain str).
    namespace: Kubernetes namespace.

  Returns:
    The standard output of the executed command.
  """
  # Point KUBECONFIG at a throwaway temp file so fetching credentials does
  # not mutate the worker's default ~/.kube/config.
  with tempfile.NamedTemporaryFile() as temp_config_file:
    env = os.environ.copy()
    env["KUBECONFIG"] = temp_config_file.name

    cmd = " && ".join([
        jobset.Command.get_credentials_command(info),
        (
            f"kubectl exec {pod_name} -n {namespace} "
            # shlex.quote replaces the previous hand-written single quotes:
            # identical output for the current quote-free scripts, but it
            # also survives scripts containing quotes, `$`, or backticks
            # instead of silently breaking the composed shell command.
            f"-- python3 -c {shlex.quote(script)}"
        ),
    ])
    return subprocess.run_exec(cmd, env=env)

0 commit comments

Comments
 (0)