Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dags/tpu_observability/jobset_ttr_kill_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
with TaskGroup( # pylint: disable=unexpected-keyword-arg
group_id=f"v{config.tpu_version.value}"
):
selector = jobset.generate_node_pool_selector("jobset-ttr-kill-process")

jobset_config = jobset.build_jobset_from_gcs_yaml(
gcs_path=GCS_JOBSET_CONFIG_PATH,
dag_name="jobset_ttr_kill_process",
node_pool_selector=selector,
)

cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
Expand All @@ -140,6 +143,7 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
is_prod=composer_env.is_prod_env(),
machine_type=config.machine_version.value,
tpu_topology=config.tpu_topology,
node_pool_selector=selector,
)

create_node_pool = node_pool.create.override(task_id="create_node_pool")(
Expand Down Expand Up @@ -190,6 +194,7 @@ def kill_tpu_pod_workload(info: node_pool.Info, pod_name: str) -> None:
)

chain(
selector,
jobset_config,
cluster_info,
create_node_pool,
Expand Down
7 changes: 7 additions & 0 deletions dags/tpu_observability/jobset_ttr_node_pool_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,14 @@
with TaskGroup( # pylint: disable=unexpected-keyword-arg
group_id=f"v{config.tpu_version.value}"
):
selector = jobset.generate_node_pool_selector(
"jobset-ttr-node-pool-resize"
)

jobset_config = jobset.build_jobset_from_gcs_yaml(
gcs_path=GCS_JOBSET_CONFIG_PATH,
dag_name="jobset_ttr_node_pool_resize",
node_pool_selector=selector,
)

cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
Expand All @@ -102,6 +107,7 @@
is_prod=composer_env.is_prod_env(),
machine_type=config.machine_version.value,
tpu_topology=config.tpu_topology,
node_pool_selector=selector,
)

create_node_pool = node_pool.create.override(task_id="create_node_pool")(
Expand Down Expand Up @@ -151,6 +157,7 @@
)

chain(
selector,
jobset_config,
cluster_info,
create_node_pool,
Expand Down
8 changes: 7 additions & 1 deletion dags/tpu_observability/jobset_ttr_pod_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,12 @@
with TaskGroup( # pylint: disable=unexpected-keyword-arg
group_id=f"v{config.tpu_version.value}"
):
selector = jobset.generate_node_pool_selector("jobset-ttr-pod-delete")

jobset_config = jobset.build_jobset_from_gcs_yaml(
gcs_path=GCS_JOBSET_CONFIG_PATH, dag_name="jobset_ttr_pod_delete"
gcs_path=GCS_JOBSET_CONFIG_PATH,
dag_name="jobset_ttr_pod_delete",
node_pool_selector=selector,
)

cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
Expand All @@ -98,6 +102,7 @@
is_prod=composer_env.is_prod_env(),
machine_type=config.machine_version.value,
tpu_topology=config.tpu_topology,
node_pool_selector=selector,
)

create_node_pool = node_pool.create.override(task_id="create_node_pool")(
Expand Down Expand Up @@ -144,6 +149,7 @@
)

chain(
selector,
jobset_config,
cluster_info,
create_node_pool,
Expand Down
8 changes: 7 additions & 1 deletion dags/tpu_observability/jobset_ttr_rollback.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@
with TaskGroup( # pylint: disable=unexpected-keyword-arg
group_id=f"v{config.tpu_version.value}"
):
selector = jobset.generate_node_pool_selector("jobset-rollback-ttr")

jobset_config = jobset.build_jobset_from_gcs_yaml(
gcs_path=GCS_JOBSET_CONFIG_PATH, dag_name="jobset_rollback_ttr"
gcs_path=GCS_JOBSET_CONFIG_PATH,
dag_name="jobset_rollback_ttr",
node_pool_selector=selector,
)

cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
Expand All @@ -100,6 +104,7 @@
is_prod=composer_env.is_prod_env(),
machine_type=config.machine_version.value,
tpu_topology=config.tpu_topology,
node_pool_selector=selector,
)

create_node_pool = node_pool.create.override(task_id="create_node_pool")(
Expand Down Expand Up @@ -146,6 +151,7 @@
)

chain(
selector,
jobset_config,
cluster_info,
create_node_pool,
Expand Down
7 changes: 7 additions & 0 deletions dags/tpu_observability/tpu_info_format_validation_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ def generate_second_node_pool_name(
"""Generates a second node pool name."""
return f"{node_pool_info.node_pool_name}-2"

selector = jobset.generate_node_pool_selector(
"tpu-info-format-validation-dag"
)

# Keyword arguments are generated dynamically at runtime (pylint does not
# know this signature).
with TaskGroup( # pylint: disable=unexpected-keyword-arg
Expand All @@ -354,6 +358,7 @@ def generate_second_node_pool_name(
jobset_config = jobset.build_jobset_from_gcs_yaml(
gcs_path=GCS_JOBSET_CONFIG_PATH,
dag_name="tpu_info_format_validation_dag",
node_pool_selector=selector,
)

cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
Expand All @@ -364,6 +369,7 @@ def generate_second_node_pool_name(
is_prod=composer_env.is_prod_env(),
machine_type=config.machine_version.value,
tpu_topology=config.tpu_topology,
node_pool_selector=selector,
)

cluster_info_2 = node_pool.copy_node_pool_info_with_override.override(
Expand Down Expand Up @@ -509,6 +515,7 @@ def generate_second_node_pool_name(
chain(cleanup_first_node_pool, cleanup_second_node_pool)

chain(
selector,
jobset_config,
cluster_info,
cluster_info_2,
Expand Down
7 changes: 7 additions & 0 deletions dags/tpu_observability/tpu_sdk_monitoring_validation_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,14 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
with TaskGroup( # pylint: disable=unexpected-keyword-arg
group_id=f"v{config.tpu_version.value}"
):
selector = jobset.generate_node_pool_selector(
"tpu-sdk-monitoring-validation"
)

jobset_config = jobset.build_jobset_from_gcs_yaml(
gcs_path=GCS_JOBSET_CONFIG_PATH,
dag_name="tpu_sdk_monitoring_validation",
node_pool_selector=selector,
)

cluster_info = node_pool.build_node_pool_info_from_gcs_yaml.override(
Expand All @@ -144,6 +149,7 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
is_prod=composer_env.is_prod_env(),
machine_type=config.machine_version.value,
tpu_topology=config.tpu_topology,
node_pool_selector=selector,
)

create_node_pool = node_pool.create.override(task_id="create_node_pool")(
Expand Down Expand Up @@ -191,6 +197,7 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
)

chain(
selector,
jobset_config,
cluster_info,
create_node_pool,
Expand Down
23 changes: 20 additions & 3 deletions dags/tpu_observability/utils/jobset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,26 @@
from dags.tpu_observability.utils import subprocess_util as subprocess
from dags.tpu_observability.utils.gcp_util import query_time_series
from dags.tpu_observability.utils.node_pool_util import Info as node_pool_info
from dags.tpu_observability.utils.node_pool_util import NODE_POOL_SELECTOR_KEY
from dags.tpu_observability.utils.time_util import TimeUtil
from google.cloud.monitoring_v3 import types
import kubernetes
from xlml.apis import gcs
from xlml.utils import gke


@task
def generate_node_pool_selector(prefix: str) -> str:
  """Generates a unique node_pool_selector value for one DAG run.

  The returned value is used both as the GKE node-pool label value (applied
  via `--node-labels` at node pool creation) and as the JobSet `nodeSelector`
  value, binding the workload pods to the node pool created for this run.

  Args:
    prefix: An identifier for the workload, typically the DAG name (e.g.,
      "jobset-ttr-kill-process", "jobset-rollback-ttr").

  Returns:
    The selector value string (e.g., "jobset-rollback-ttr-20260212123456").
  """
  # Use UTC so the generated value is unambiguous regardless of the
  # scheduler/worker local timezone; the format is unchanged.
  # NOTE(review): the value is consumed as a Kubernetes label value, which
  # is capped at 63 characters — with the 15-char "-YYYYmmddHHMMSS" suffix,
  # prefixes must stay under ~48 characters; confirm callers comply.
  run_id = datetime.datetime.now(datetime.timezone.utc).strftime(
      "%Y%m%d%H%M%S"
  )
  return f"{prefix}-{run_id}"


class Workload:
"""A library of predefined workload scripts for JobSet.

Expand Down Expand Up @@ -129,7 +142,7 @@ def matmul_ultra_heavy(x, y):
# pylint: disable=line-too-long
_TEMPLATE = string.Template(
textwrap.dedent(
"""
f"""
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
Expand All @@ -153,6 +166,7 @@ def matmul_ultra_heavy(x, y):
nodeSelector:
cloud.google.com/gke-tpu-accelerator: $tpu_accelerator_type
cloud.google.com/gke-tpu-topology: $tpu_topology
{NODE_POOL_SELECTOR_KEY}: $node_pool_selector
containers:
- name: $container_name
image: $image
Expand Down Expand Up @@ -212,6 +226,7 @@ class JobSet:
container_name: str
image: str
tpu_cores_per_pod: int
node_pool_selector: str

def generate_yaml(self, workload_script: Workload) -> str:
"""Generates the final JobSet YAML content.
Expand All @@ -226,6 +241,7 @@ def generate_yaml(self, workload_script: Workload) -> str:
params = dataclasses.asdict(self)
params["command"] = ["bash", "-c"]
params["args"] = workload_script
params["node_pool_selector"] = self.node_pool_selector or ""

return _TEMPLATE.substitute(params)

Expand Down Expand Up @@ -472,6 +488,7 @@ def run_workload(
Args:
node_pool: Configuration object with cluster details.
jobset_config: The JobSet object containing YAML configuration.
workload_type: The workload script to execute.
Returns:
The UTC time when the workload was started.
"""
Expand Down
21 changes: 20 additions & 1 deletion dags/tpu_observability/utils/node_pool_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@
from xlml.utils import composer


NODE_POOL_SELECTOR_KEY = "tpu-observability/workload"
"""The label key for binding JobSet workloads to specific GKE node pools.

This key is used as a Kubernetes node label to ensure pods are scheduled
on the correct node pool. It is applied to both:
- GKE node pools via `--node-labels` during creation
- JobSet YAML via `nodeSelector` to target the labeled nodes
"""


class Status(enum.Enum):
"""Enum for GKE node pool status."""

Expand Down Expand Up @@ -71,6 +81,7 @@ class Info:
num_nodes: int = None
tpu_topology: str = None
reservation: str = None
node_pool_selector: str = None


@task
Expand Down Expand Up @@ -183,7 +194,12 @@ def create(
node_pool: Info,
ignore_failure: bool = False,
) -> None:
"""Creates a GKE node pool by the given node pool information."""
"""Creates a GKE node pool by the given node pool information.

Args:
node_pool: The node pool configuration.
ignore_failure: If True, command failures are ignored.
"""

composer.log_metadata_for_xlml_dashboard({
"cluster_project": node_pool.project_id,
Expand Down Expand Up @@ -214,6 +230,9 @@ def create(
if node_pool.reservation:
command += f" --reservation-affinity=specific --reservation={node_pool.reservation}"

if node_pool.node_pool_selector:
command += f" --node-labels={NODE_POOL_SELECTOR_KEY}={node_pool.node_pool_selector}"

if ignore_failure:
command += "2>&1 || true "

Expand Down
Loading