Skip to content

Commit cab29a7

Browse files
authored
Add a new DAG that tests node pool TTR by killing process on pods (#1173)
This change introduces a new Airflow DAG, `jobset_ttr_kill_process`, designed to validate the Time-To-Recover (TTR) metrics for TPU JobSets. The DAG simulates a workload failure by injecting a fault (killing the main Python process) and monitors the system's ability to recover and log the recovery duration.
1 parent defcd3d commit cab29a7

File tree

2 files changed

+212
-6
lines changed

2 files changed

+212
-6
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
A DAG to test jobset time-to-recover metric by
17+
killing the main process inside a worker Pod.
18+
"""
19+
20+
import datetime
21+
import logging
22+
import tempfile
23+
import os
24+
25+
from airflow import models
26+
from airflow.decorators import task
27+
from airflow.models.baseoperator import chain
28+
from airflow.utils.trigger_rule import TriggerRule
29+
from airflow.utils.task_group import TaskGroup
30+
31+
32+
from dags import composer_env
33+
from dags.tpu_observability.utils import jobset_util as jobset
34+
from dags.tpu_observability.utils import node_pool_util as node_pool
35+
from dags.tpu_observability.utils import subprocess_util as subprocess
36+
from dags.tpu_observability.utils.jobset_util import JobSet, Workload
37+
from dags.tpu_observability.configs.common import (
38+
MachineConfigMap,
39+
GCS_CONFIG_PATH,
40+
GCS_JOBSET_CONFIG_PATH,
41+
)
42+
43+
44+
@task
def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
  """Kills the main python process on a single pod.

  This task fetches cluster credentials into a throwaway kubeconfig, then
  sends SIGKILL (`pkill -9 -f python`) to the python process inside the
  specified pod to simulate a crash of the main training process.

  Args:
    info: Node pool/cluster info used to build the get-credentials command.
    pod_name: Name of the pod (in the `default` namespace) to target.

  Raises:
    Whatever subprocess.run_exec raises on a non-137 failure (e.g. the
    kubectl command itself failing). NOTE(review): the original docstring
    claimed already-deleted pods are ignored, but no such handling exists
    here — confirm whether that is required and add it if so.
  """
  # Isolate credentials in a temporary kubeconfig so concurrently mapped
  # kill tasks do not race on the shared default kubeconfig file.
  with tempfile.NamedTemporaryFile() as temp_config_file:
    env = os.environ.copy()
    env["KUBECONFIG"] = temp_config_file.name

    cmd = " && ".join([
        jobset.Command.get_credentials_command(info),
        f"kubectl exec {pod_name} -n default -- pkill -9 -f python",
    ])

    try:
      subprocess.run_exec(cmd, env=env)
    except subprocess.ProcessKilledException:
      # Exit code 137 means the target died from SIGKILL — the intended
      # outcome of this fault injection, so treat it as success.
      logging.info("Process was terminated with SIGKILL")
    # Any other exception propagates unchanged; the previous
    # `except Exception as e: raise e` clause was a no-op re-raise that
    # only obscured the traceback origin, so it has been removed.
68+
69+
70+
# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with models.DAG(  # pylint: disable=unexpected-keyword-arg
    dag_id="jobset_ttr_kill_process",
    start_date=datetime.datetime(2025, 8, 10),
    # Daily at 15:00 UTC in prod; manual-trigger only elsewhere.
    schedule="0 15 * * *" if composer_env.is_prod_env() else None,
    catchup=False,
    tags=[
        "cloud-ml-auto-solutions",
        "jobset",
        "time-to-recover",
        "tpu-observability",
        "kill-main-process",
        "TPU",
        "v6e-16",
    ],
    description=(
        "This DAG tests the use of killing the main process inside a jobset "
        "pod to interrupt a jobset, then polls the jobset time-to-recover "
        "metric to check if it is updated."
    ),
    doc_md="""
# JobSet Time-To-Recover (TTR) Test by Killing Main Process

### Description
This DAG validates the **Time-To-Recover (TTR)** metric by simulating a software-level failure.
It provisions a TPU node pool, launches a JobSet workload, and then intentionally
terminates the main Python process inside the worker Pods to trigger a recovery event.

### Prerequisites
* Access to a GKE cluster with TPU support.
* The `tpu-info` container image must be accessible by the cluster.
* GCS configuration must be present at the defined `GCS_CONFIG_PATH`.

### Procedures
1. **Environment Setup**: Dynamically builds node pool info and creates a dedicated TPU node pool.
2. **Workload Launch**: Applies a JobSet YAML configured for JAX TPU benchmarks.
3. **Fault Injection**: Once the job is started, the DAG executes `pkill -9 -f python`
   inside the worker Pods via `kubectl exec`. This simulates a crash of the main training process.
4. **Metric Monitoring**: A sensor waits for the system to detect the failure, restart the
   workload, and successfully publish the `time-to-recover` metric.
5. **Cleanup**: Automatically tears down the JobSet and deletes the TPU node pool to
   ensure no resource leakage, regardless of whether the test passed or failed.
""",
) as dag:
  # One task group per machine configuration (presumably one per TPU
  # version — confirm against MachineConfigMap's members).
  for machine in MachineConfigMap:
    config = machine.value

    # Keyword arguments are generated dynamically at runtime (pylint does not
    # know this signature).
    with TaskGroup(  # pylint: disable=unexpected-keyword-arg
        group_id=f"v{config.tpu_version.value}"
    ):
      # Load the JobSet spec for this DAG from its YAML stored in GCS.
      jobset_config = jobset.build_jobset_from_gcs_yaml(
          gcs_path=GCS_JOBSET_CONFIG_PATH,
          dag_name="jobset_ttr_kill_process",
      )

      # Build node-pool/cluster connection info from the GCS config,
      # parameterized by the current machine configuration.
      cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
          task_id="build_node_pool_info_from_gcs_yaml"
      )(
          gcs_path=GCS_CONFIG_PATH,
          dag_name="jobset_ttr_kill_process",
          is_prod=composer_env.is_prod_env(),
          machine_type=config.machine_version.value,
          tpu_topology=config.tpu_topology,
      )

      # Provision a dedicated TPU node pool for this test run.
      create_node_pool = node_pool.create.override(task_id="create_node_pool")(
          node_pool=cluster_info,
      )

      # Apply the JobSet (JAX TPU benchmark workload); the returned value
      # is used downstream as the job apply time.
      apply_time = jobset.run_workload.override(task_id="run_workload")(
          node_pool=cluster_info,
          jobset_config=jobset_config,
          workload_type=Workload.JAX_TPU_BENCHMARK,
      )

      # Enumerate the worker pods spawned by the JobSet.
      pod_names = jobset.list_pod_names.override(task_id="list_pod_names")(
          node_pool=cluster_info,
          jobset_config=jobset_config,
      )

      # Block until the JobSet has actually started before injecting the
      # fault, so the kill targets a running process.
      wait_for_job_start = jobset.wait_for_jobset_started.override(
          task_id="wait_for_job_start"
      )(cluster_info, pod_name_list=pod_names, job_apply_time=apply_time)

      # Fault injection: dynamically map one kill task per pod
      # (partial binds the shared cluster info; expand fans out over pods).
      kill_tasks = (
          kill_tpu_pod_workload.override(task_id="kill_tpu_pod_workload")
          .partial(info=cluster_info)
          .expand(pod_name=pod_names)
      )

      # Sensor: wait for the time-to-recover metric to be published after
      # the workload recovers from the injected failure.
      wait_for_metric_upload = jobset.wait_for_jobset_ttr_to_be_found.override(
          task_id="wait_for_metric_upload"
      )(
          node_pool=cluster_info,
          jobset_config=jobset_config,
      )

      # Teardown: delete the JobSet regardless of upstream outcome
      # (ALL_DONE + as_teardown pairs it with the run_workload setup).
      cleanup_workload = jobset.end_workload.override(
          task_id="cleanup_workload", trigger_rule=TriggerRule.ALL_DONE
      )(
          node_pool=cluster_info,
          jobset_config=jobset_config,
      ).as_teardown(
          setups=apply_time
      )

      # Teardown: delete the TPU node pool regardless of upstream outcome,
      # paired with the create_node_pool setup.
      cleanup_node_pool = node_pool.delete.override(
          task_id="cleanup_node_pool", trigger_rule=TriggerRule.ALL_DONE
      )(node_pool=cluster_info).as_teardown(
          setups=create_node_pool,
      )

      # Linear execution order: setup -> launch -> fault -> verify -> cleanup.
      chain(
          jobset_config,
          cluster_info,
          create_node_pool,
          apply_time,
          pod_names,
          wait_for_job_start,
          kill_tasks,
          wait_for_metric_upload,
          cleanup_workload,
          cleanup_node_pool,
      )

dags/tpu_observability/utils/subprocess_util.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
import logging
3030
import subprocess
3131

32-
from airflow.exceptions import AirflowFailException
32+
from airflow.exceptions import AirflowException
33+
34+
35+
class ProcessKilledException(AirflowException):
  """Raised specifically when a command returns exit code 137 (SIGKILL)."""
  # The redundant `pass` was removed: a docstring already constitutes the
  # class body, so `pass` was dead weight.
3339

3440

3541
def run_exec(
@@ -39,6 +45,7 @@ def run_exec(
3945
log_output: bool = True,
4046
) -> str:
4147
"""Executes a shell command and logs its output."""
48+
4249
if log_command:
4350
logging.info("[subprocess] executing command:\n %s\n", cmd)
4451

@@ -60,13 +67,16 @@ def run_exec(
6067
# (using the default system encoding).
6168
text=True,
6269
)
63-
6470
if res.returncode != 0:
6571
logging.info("[subprocess] stderr: %s", res.stderr)
66-
raise AirflowFailException(
67-
"Caught an error while executing a command. stderr Message:"
68-
f" {res.stderr}"
69-
)
72+
match res.returncode:
73+
case 137:
74+
raise ProcessKilledException()
75+
case _:
76+
raise AirflowException(
77+
f"Caught an error while executing a command. \n"
78+
f"stderr Message: {res.stderr}"
79+
)
7080

7181
if log_output:
7282
logging.info("[subprocess] stdout: %s", res.stdout)

0 commit comments

Comments
 (0)