Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions dags/tpu_observability/jobset_uptime_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@
@task
def get_current_time() -> TimeUtil:
"""Get the current time in UTC."""
current_time_utc = datetime.datetime.now(datetime.timezone.utc)
return TimeUtil.from_datetime(current_time_utc)
return TimeUtil.now()


# Keyword arguments are generated dynamically at runtime (pylint does not
Expand Down
45 changes: 17 additions & 28 deletions dags/tpu_observability/utils/jobset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import enum
import dataclasses
import datetime
from datetime import timedelta
import json
import logging
import os
Expand All @@ -33,7 +33,6 @@
import kubernetes

from dags.tpu_observability.utils import subprocess_util as subprocess
from dags.tpu_observability.utils.gcp_util import query_time_series
from dags.tpu_observability.utils.gcp_util import list_time_series
from dags.tpu_observability.utils.node_pool_util import Info as node_pool_info
from dags.tpu_observability.utils.node_pool_util import NODE_POOL_SELECTOR_KEY
Expand All @@ -53,7 +52,7 @@ def generate_node_pool_selector(prefix: str) -> str:
Returns:
The selector value string (e.g., "rollback-20260212123456").
"""
run_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
run_id = TimeUtil.now().to_datetime().strftime("%Y%m%d%H%M%S")
return f"{prefix}-{run_id}"


Expand Down Expand Up @@ -436,7 +435,7 @@ def _generate_jobset_name(dag_id_prefix: str) -> str:
Returns:
A string representing the generated jobset name.
"""
now_utc = datetime.datetime.now(datetime.timezone.utc)
now_utc = TimeUtil.now().to_datetime()
timestamp = now_utc.strftime("%Y%m%d%H%M%S")
dag_id_prefix = dag_id_prefix.replace("_", "-").lower()

Expand Down Expand Up @@ -527,8 +526,7 @@ def run_workload(
"Logged JobSet metadata to XLML dashboard: %s", jobset_metadata
)

current_time_utc = datetime.datetime.now(datetime.timezone.utc)
return TimeUtil.from_datetime(current_time_utc)
return TimeUtil.now()


@task
Expand Down Expand Up @@ -681,11 +679,8 @@ def wait_for_jobset_started(
job_apply_time: The datetime object of the time the job was applied.
"""

end_time_datatime = job_apply_time.to_datetime() + datetime.timedelta(
minutes=10
)
start_time = job_apply_time
end_time = TimeUtil.from_datetime(end_time_datatime)
end_time = job_apply_time + timedelta(minutes=10)

if not pod_name_list:
raise AirflowFailException("pod_name_list is empty, sensor cannot proceed.")
Expand Down Expand Up @@ -752,11 +747,8 @@ def wait_for_jobset_ttr_to_be_found(
Returns:
bool: True if the TTR metric is found in Cloud Monitoring, False otherwise.
"""
now = datetime.datetime.now()
query_start = (
start_time
if start_time
else TimeUtil.from_datetime(now - datetime.timedelta(minutes=60))
start_time if start_time else TimeUtil.now() - timedelta(minutes=60)
)

time_series = list_time_series(
Expand All @@ -767,7 +759,7 @@ def wait_for_jobset_ttr_to_be_found(
f'resource.labels.entity_name="{jobset_config.jobset_name}"'
),
start_time=query_start,
end_time=TimeUtil.from_datetime(now),
end_time=TimeUtil.now(),
)

logging.info("Time series: %s", time_series)
Expand Down Expand Up @@ -808,21 +800,18 @@ def wait_for_all_pods_running(
def query_uptime_metrics(
node_pool: node_pool_info,
jobset_name: str,
start_time: datetime.datetime,
end_time: datetime.datetime,
start_time: TimeUtil,
end_time: TimeUtil,
):
"""Queries the JobSet's uptime metric from Cloud Monitoring."""
start_time = TimeUtil.from_datetime(start_time)
end_time = TimeUtil.from_datetime(end_time)

filter_string = [
'metric.type="kubernetes.io/jobset/uptime"',
f'resource.labels.project_id = "{node_pool.project_id}"',
f'resource.labels.cluster_name = "{node_pool.cluster_name}"',
f'resource.labels.entity_name = "{jobset_name}"',
]

return query_time_series(
return list_time_series(
project_id=node_pool.project_id,
filter_str=" AND ".join(filter_string),
start_time=start_time,
Expand All @@ -832,15 +821,15 @@ def query_uptime_metrics(
)


@task.sensor(poke_interval=30, timeout=3600, mode="reschedule")
@task.sensor(poke_interval=30, timeout=3600, mode="poke")
def wait_for_jobset_uptime_data(
node_pool: node_pool_info,
jobset_config: JobSet,
jobset_apply_time: TimeUtil,
):
"""Verify uptime data exists after jobset application."""
start_time = jobset_apply_time.to_datetime()
end_time = datetime.datetime.now(datetime.timezone.utc)
start_time = jobset_apply_time
end_time = TimeUtil.now()
data = query_uptime_metrics(
node_pool, jobset_config.jobset_name, start_time, end_time
)
Expand All @@ -851,16 +840,16 @@ def wait_for_jobset_uptime_data(
return False


@task.sensor(poke_interval=30, timeout=360, mode="reschedule")
@task.sensor(poke_interval=30, timeout=360, mode="poke")
def ensure_no_jobset_uptime_data(
node_pool: node_pool_info,
jobset_config: JobSet,
jobset_clear_time: TimeUtil,
wait_time_seconds: int,
):
"""Ensure no uptime data is recorded after jobset deletion."""
start_time = jobset_clear_time.to_datetime()
now = datetime.datetime.now(datetime.timezone.utc)
start_time = jobset_clear_time
now = TimeUtil.now()
data = query_uptime_metrics(
node_pool, jobset_config.jobset_name, start_time, now
)
Expand All @@ -869,7 +858,7 @@ def ensure_no_jobset_uptime_data(
if len(data) > 0:
raise AirflowFailException(f"Data detected: {data}")

if now - start_time >= datetime.timedelta(seconds=wait_time_seconds):
if (now - start_time).to_unix_seconds() >= wait_time_seconds:
logging.info("Stability period passed with no data detected.")
return True
return False
23 changes: 16 additions & 7 deletions dags/tpu_observability/utils/time_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility class for handling various time representations."""

from __future__ import annotations

import datetime
from dataclasses import dataclass

Expand All @@ -13,23 +15,28 @@ class TimeUtil:
time: int

@classmethod
def from_iso_string(cls, time_str: str) -> "TimeUtil":
def now(cls) -> TimeUtil:
"""Returns the current time in UTC."""
return cls(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))

@classmethod
def from_iso_string(cls, time_str: str) -> TimeUtil:
"""Builds a TimeUtil object from an ISO 8601 formatted string."""
dt_object = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
return cls(int(dt_object.timestamp()))

@classmethod
def from_timestamp_pb2(cls, ts_pb: timestamp_pb2.Timestamp) -> "TimeUtil":
def from_timestamp_pb2(cls, ts_pb: timestamp_pb2.Timestamp) -> TimeUtil:
"""Builds a TimeUtil object from a Google Protobuf Timestamp."""
return cls(int(ts_pb.seconds))

@classmethod
def from_datetime(cls, dt: datetime.datetime) -> "TimeUtil":
def from_datetime(cls, dt: datetime.datetime) -> TimeUtil:
"""Builds a TimeUtil object from a standard datetime object."""
return cls(int(dt.timestamp()))

@classmethod
def from_unix_seconds(cls, unix_seconds: int | float) -> "TimeUtil":
def from_unix_seconds(cls, unix_seconds: int | float) -> TimeUtil:
"""Builds a TimeUtil object from a Unix timestamp (seconds)."""
return cls(int(unix_seconds))

Expand All @@ -52,16 +59,18 @@ def to_mql_string(self) -> str:
dt = self.to_datetime()
return dt.strftime("d'%Y/%m/%d-%H:%M:%S'")

def __add__(self, other: datetime.timedelta) -> "TimeUtil":
def __add__(self, other: datetime.timedelta) -> TimeUtil:
"""Allows usage like: TimeUtil(...) + timedelta(minutes=10)."""
if isinstance(other, datetime.timedelta):
return TimeUtil(self.time + int(other.total_seconds()))
return NotImplemented

def __sub__(self, other: datetime.timedelta) -> "TimeUtil":
"""Allows usage like: TimeUtil(...) - timedelta(minutes=10)."""
def __sub__(self, other: datetime.timedelta | TimeUtil) -> TimeUtil:
"""Allows usage like: TimeUtil(...) - timedelta(minutes=10) or TimeUtil(...) - TimeUtil(...)."""
if isinstance(other, datetime.timedelta):
return TimeUtil(self.time - int(other.total_seconds()))
if isinstance(other, TimeUtil):
return TimeUtil(self.time - other.time)
return NotImplemented


Expand Down
Loading