Skip to content

Commit 02f90e3

Browse files
feat: Enhance healthiness and interruption reason for "PLX dashboard" (#1191)
Add additional metrics to the XLML metadata logging for JobSet observability. - `jobset_name`: add the jobset name. - `pod_name_pattern`: a regex pattern matching the names of pods associated with the JobSet (the jobset-name prefix is stable across pod recreations, making a pattern more reliable than an exact pod-name list captured before the interruption event). Co-authored-by: Chris Liao <388chris@gmail.com>
1 parent 68f4a34 commit 02f90e3

File tree

5 files changed

+76
-28
lines changed

5 files changed

+76
-28
lines changed

dags/tpu_observability/jobset_ttr_kill_process.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,19 +156,25 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
156156
workload_type=Workload.JAX_TPU_BENCHMARK,
157157
)
158158

159-
pod_names = jobset.list_pod_names.override(task_id="list_pod_names")(
159+
running_pods = jobset.wait_for_all_pods_running.override(
160+
task_id="ensure_all_pods_running"
161+
)(
160162
node_pool=cluster_info,
161163
jobset_config=jobset_config,
162164
)
163165

164166
wait_for_job_start = jobset.wait_for_jobset_started.override(
165167
task_id="wait_for_job_start"
166-
)(cluster_info, pod_name_list=pod_names, job_apply_time=apply_time)
168+
)(
169+
cluster_info,
170+
pod_name_list=running_pods,
171+
job_apply_time=apply_time,
172+
)
167173

168174
kill_tasks = (
169175
kill_tpu_pod_workload.override(task_id="kill_tpu_pod_workload")
170176
.partial(info=cluster_info)
171-
.expand(pod_name=pod_names)
177+
.expand(pod_name=running_pods)
172178
)
173179

174180
wait_for_metric_upload = jobset.wait_for_jobset_ttr_to_be_found.override(
@@ -199,7 +205,7 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
199205
cluster_info,
200206
create_node_pool,
201207
apply_time,
202-
pod_names,
208+
running_pods,
203209
wait_for_job_start,
204210
kill_tasks,
205211
wait_for_metric_upload,

dags/tpu_observability/jobset_ttr_pod_delete.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@
5656
],
5757
description=(
5858
"This DAG tests the JobSet time-to-recover metric by deleting a random "
59-
"pod to trigger a recovery, then polls the metric to check if it is updated."
59+
"pod to trigger a recovery, then polls the metric to check if it is"
60+
" updated."
6061
),
6162
doc_md="""
6263
# JobSet Time-To-Recover (TTR) Test Using Random Pod Deletion

dags/tpu_observability/tpu_info_format_validation_dags.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,23 +406,25 @@ def generate_second_node_pool_name(
406406
workload_type=Workload.JAX_TPU_BENCHMARK,
407407
)
408408

409-
pod_names = jobset.list_pod_names.override(
410-
task_id="list_pod_names",
411-
retries=5,
412-
retry_delay=datetime.timedelta(seconds=10),
409+
running_pods = jobset.wait_for_all_pods_running.override(
410+
task_id="ensure_all_pods_running"
413411
)(
414412
node_pool=cluster_info,
415413
jobset_config=jobset_config,
416414
)
417415

418416
wait_for_job_start = jobset.wait_for_jobset_started.override(
419417
task_id="wait_for_job_start"
420-
)(cluster_info, pod_name_list=pod_names, job_apply_time=apply_time)
418+
)(
419+
cluster_info,
420+
pod_name_list=running_pods,
421+
job_apply_time=apply_time,
422+
)
421423

422424
outputs_of_tpu_info = (
423425
get_tpu_info_from_pod.override(task_id="get_tpu_info")
424426
.partial(info=cluster_info)
425-
.expand(pod_name=pod_names)
427+
.expand(pod_name=running_pods)
426428
)
427429

428430
output_of_tpu_info = (
@@ -521,7 +523,7 @@ def generate_second_node_pool_name(
521523
cluster_info_2,
522524
create_node_pool,
523525
apply_time,
524-
pod_names,
526+
running_pods,
525527
wait_for_job_start,
526528
outputs_of_tpu_info,
527529
output_of_tpu_info,

dags/tpu_observability/tpu_sdk_monitoring_validation_dag.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,25 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
162162
workload_type=Workload.JAX_TPU_BENCHMARK,
163163
)
164164

165-
pod_names = jobset.list_pod_names.override(task_id="list_pod_names")(
165+
running_pods = jobset.wait_for_all_pods_running.override(
166+
task_id="ensure_all_pods_running"
167+
)(
166168
node_pool=cluster_info,
167169
jobset_config=jobset_config,
168170
)
169171

170172
wait_for_jobset_started = jobset.wait_for_jobset_started.override(
171173
task_id="wait_for_jobset_started"
172174
)(
173-
node_pool=cluster_info,
174-
pod_name_list=pod_names,
175+
cluster_info,
176+
pod_name_list=running_pods,
175177
job_apply_time=apply_time,
176178
)
177179

178180
sdk_validation = (
179181
validate_monitoring_sdk.override(task_id="sdk_validation")
180182
.partial(info=cluster_info)
181-
.expand(pod_name=pod_names)
183+
.expand(pod_name=running_pods)
182184
)
183185

184186
cleanup_workload = jobset.end_workload.override(
@@ -202,7 +204,7 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
202204
cluster_info,
203205
create_node_pool,
204206
apply_time,
205-
pod_names,
207+
running_pods,
206208
wait_for_jobset_started,
207209
sdk_validation,
208210
cleanup_workload,

dags/tpu_observability/utils/jobset_util.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from airflow.decorators import task
3030
from airflow.exceptions import AirflowFailException
31+
from airflow.sensors.base import PokeReturnValue
3132
from google.cloud.monitoring_v3 import types
3233
import kubernetes
3334

@@ -38,6 +39,7 @@
3839
from dags.tpu_observability.utils.node_pool_util import NODE_POOL_SELECTOR_KEY
3940
from dags.tpu_observability.utils.time_util import TimeUtil
4041
from xlml.apis import gcs
42+
from xlml.utils import composer
4143
from xlml.utils import gke
4244

4345

@@ -507,6 +509,24 @@ def run_workload(
507509

508510
subprocess.run_exec(cmd, env=env)
509511

512+
# Log metadata for XLML dashboard
513+
# Pod names follow the pattern:
514+
# {jobset_name}-{replicated_job_name}-{job-index}-{pod-index}-{random}
515+
# The jobset_name prefix is stable across pod recreations, so a regex
516+
# pattern is more reliable than an exact pod name list.
517+
pod_name_pattern = f"{jobset_config.jobset_name}.*"
518+
jobset_metadata = {
519+
"project_id": node_pool.project_id,
520+
"cluster_name": node_pool.cluster_name,
521+
"node_pool_name": node_pool.node_pool_name,
522+
"jobset_name": jobset_config.jobset_name,
523+
"pod_name_pattern": pod_name_pattern,
524+
}
525+
composer.log_metadata_for_xlml_dashboard(jobset_metadata)
526+
logging.info(
527+
"Logged JobSet metadata to XLML dashboard: %s", jobset_metadata
528+
)
529+
510530
current_time_utc = datetime.datetime.now(datetime.timezone.utc)
511531
return TimeUtil.from_datetime(current_time_utc)
512532

@@ -724,7 +744,8 @@ def wait_for_jobset_ttr_to_be_found(
724744
725745
Args:
726746
node_pool (Info): An instance of the Info class containing GKE metadata.
727-
jobset_config: An instance of the JobSet class representing the jobset configuration.
747+
jobset_config: An instance of the JobSet class representing the jobset
748+
configuration.
728749
start_time (TimeUtil, optional): The UTC timestamp to start polling from.
729750
If not provided, defaults to 60 minutes before the current time.
730751
@@ -749,23 +770,39 @@ def wait_for_jobset_ttr_to_be_found(
749770
end_time=TimeUtil.from_datetime(now),
750771
)
751772

752-
# This function checks whether the TTR metric is present;
753-
# it does not assess its value.
754773
logging.info("Time series: %s", time_series)
755774
return len(time_series) > 0
756775

757776

758777
@task.sensor(poke_interval=30, timeout=600, mode="poke")
759-
def wait_for_all_pods_running(node_pool: node_pool_info, jobset_config: JobSet):
760-
num_running = len(
761-
get_running_pods(
762-
node_pool=node_pool,
763-
jobset_name=jobset_config.jobset_name,
764-
namespace="default",
765-
)
778+
def wait_for_all_pods_running(
779+
node_pool: node_pool_info, jobset_config: JobSet
780+
) -> PokeReturnValue:
781+
"""Waits for all pods to be running and returns the pod names.
782+
783+
Args:
784+
node_pool: The Info object containing the cluster information.
785+
jobset_config: The JobSet configuration.
786+
787+
Returns:
788+
PokeReturnValue with is_done=True and pod names when all pods are running,
789+
or is_done=False to continue polling.
790+
"""
791+
running_pods = get_running_pods(
792+
node_pool=node_pool,
793+
jobset_name=jobset_config.jobset_name,
794+
namespace="default",
766795
)
767796
num_pods = jobset_config.replicas * jobset_config.parallelism
768-
return num_running == num_pods
797+
if len(running_pods) == num_pods:
798+
logging.info(
799+
"All %d pods are running for JobSet '%s': %s",
800+
num_pods,
801+
jobset_config.jobset_name,
802+
running_pods,
803+
)
804+
return PokeReturnValue(is_done=True, xcom_value=running_pods)
805+
return PokeReturnValue(is_done=False)
769806

770807

771808
def query_uptime_metrics(

0 commit comments

Comments
 (0)