diff --git a/pathwaysutils/experimental/shared_pathways_service/__init__.py b/pathwaysutils/experimental/shared_pathways_service/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py
new file mode 100644
index 0000000..971e896
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py
@@ -0,0 +1,323 @@
+"""GKE utils for deploying and managing the Pathways proxy."""
+
+import logging
+import socket
+import subprocess
+import urllib.parse
+
+import portpicker
+
+_logger = logging.getLogger(__name__)
+
+# TODO(b/456189271): Evaluate and replace the subprocess calls with Kubernetes
+# Python API for kubectl calls.
+
+
+def fetch_cluster_credentials(
+    *, cluster_name: str, project_id: str, location: str
+) -> None:
+  """Fetches credentials for the GKE cluster."""
+  _logger.info("Fetching credentials for '%s'.", cluster_name)
+  get_credentials_command = [
+      "gcloud",
+      "container",
+      "clusters",
+      "get-credentials",
+      cluster_name,
+      f"--location={location}",
+      f"--project={project_id}",
+  ]
+  try:
+    subprocess.run(
+        get_credentials_command,
+        check=True,
+        capture_output=True,
+        text=True,
+    )
+  except subprocess.CalledProcessError as e:
+    _logger.exception(
+        "Failed to get cluster credentials. gcloud output:\n%s", e.stderr
+    )
+    raise
+
+
+def deploy_gke_yaml(yaml: str) -> None:
+  """Deploys the given YAML to the GKE cluster.
+
+  Args:
+    yaml: The GKE YAML to deploy.
+
+  Raises:
+    subprocess.CalledProcessError: If the kubectl command fails.
+  """
+  _logger.info("Deploying GKE YAML: %s", yaml)
+  kubectl_apply_command = ["kubectl", "apply", "-f", "-"]
+  try:
+    proxy_result = subprocess.run(
+        kubectl_apply_command,
+        input=yaml,
+        check=True,
+        capture_output=True,
+        text=True,
+    )
+  except subprocess.CalledProcessError as e:
+    _logger.exception(
+        "Failed to deploy the GKE YAML. kubectl output:\n%s", e.stderr
+    )
+    raise
+
+  _logger.info(
+      "Successfully deployed the GKE YAML. %s", proxy_result.stdout
+  )
+
+
+def get_pod_from_job(job_name: str) -> str:
+  """Returns the pod name for the given job.
+
+  Args:
+    job_name: The name of the job.
+
+  Returns:
+    The name of the pod.
+
+  Raises:
+    subprocess.CalledProcessError: If the kubectl command fails.
+    RuntimeError: If the pod is missing or the pod name is not in the expected
+      format.
+  """
+  get_pod_command = [
+      "kubectl",
+      "get",
+      "pods",
+      "-l",
+      f"job-name={job_name}",
+      "-o",
+      "name",
+  ]
+  try:
+    pod_result = subprocess.run(
+        get_pod_command,
+        check=True,
+        capture_output=True,
+        text=True,
+    )
+  except subprocess.CalledProcessError as e:
+    _logger.exception(
+        "Failed to get pod name. kubectl output:\n%s", e.stderr
+    )
+    raise
+
+  pod_name = pod_result.stdout.strip()
+  _logger.info("Pod name: %s", pod_name)
+
+  if (
+      not pod_name
+      or not pod_name.startswith("pod/")
+      or len(pod_name.split("/")) != 2
+  ):
+    raise RuntimeError(
+        "Failed to get pod name. Expected format: pod/<pod-name>. Got:"
+        f" {pod_name}"
+    )
+
+  # pod_name is in the format "pod/<pod-name>". We only need the pod name.
+  _, pod_name = pod_name.split("/")
+  return pod_name
+
+
+def check_pod_ready(pod_name: str, timeout: int = 30) -> str:
+  """Checks if the given pod is ready.
+
+  Args:
+    pod_name: The name of the pod.
+    timeout: The maximum time in seconds to wait for the pod to be ready.
+
+  Returns:
+    The name of the pod.
+
+  Raises:
+    RuntimeError: If the pod fails to become ready within the timeout.
+  """
+  wait_command = [
+      "kubectl",
+      "wait",
+      "--for=condition=Ready",
+      f"pod/{pod_name}",
+      f"--timeout={timeout}s",
+  ]
+  try:
+    subprocess.run(wait_command, check=True, capture_output=True, text=True)
+  except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
+    _logger.exception("Pod failed to become ready: %r", e)
+    raise RuntimeError(
+        f"Pod did not become ready: {e.stderr}."
+    ) from e
+  except Exception as e:
+    _logger.exception("Error setting up the pod: %r", e)
+    raise
+
+  _logger.info("Pod is ready: %s.", pod_name)
+  return pod_name
+
+
+def get_log_link(*, cluster: str, project: str, job_name: str) -> str:
+  """Returns a link to Cloud Logging for the given cluster and job name."""
+  log_filter = (
+      'resource.type="k8s_container"\n'
+      f'resource.labels.cluster_name="{cluster}"\n'
+      'resource.labels.namespace_name="default"\n'
+      f'labels.k8s-pod/job-name="{job_name}"'
+  )
+  encoded_filter = urllib.parse.quote(log_filter, safe="")
+
+  return (
+      "https://console.cloud.google.com/logs/query;"
+      f"query={encoded_filter};duration=PT1H"
+      f"?project={project}"
+  )
+
+
+def wait_for_pod(job_name: str) -> str:
+  """Waits for the given job's pod to be ready.
+
+  Args:
+    job_name: The name of the job.
+
+  Returns:
+    The name of the pod.
+
+  Raises:
+    RuntimeError: If the pod is not ready.
+  """
+  _logger.info("Waiting for pod to be created...")
+  pod_name = get_pod_from_job(job_name)
+
+  _logger.info(
+      "Pod created: %s. Waiting for it to be ready...", pod_name
+  )
+
+  return check_pod_ready(pod_name)
+
+
+def __test_pod_connection(port: int) -> None:
+  """Tests the connection to the pod.
+
+  Args:
+    port: The port of the pod to connect to.
+
+  Raises:
+    RuntimeError: If the connection to the pod cannot be established.
+  """
+  _logger.info("Connecting to localhost:%d", port)
+  try:
+    with socket.create_connection(("localhost", port), timeout=30):
+      _logger.info("Pod is ready.")
+  except (socket.timeout, ConnectionRefusedError) as exc:
+    raise RuntimeError("Could not connect to the pod.") from exc
+
+
+def enable_port_forwarding(
+    pod_name: str,
+    server_port: int,
+) -> tuple[int, subprocess.Popen[str]]:
+  """Enables port forwarding for the given pod.
+
+  Args:
+    pod_name: The name of the pod.
+    server_port: The port of the server to forward to.
+
+  Returns:
+    A tuple containing the local port and the port forwarding process.
+
+  Raises:
+    RuntimeError: If port forwarding fails to start or the pod connection
+      cannot be established.
+  """
+  try:
+    port_available = portpicker.pick_unused_port()
+  except Exception as e:
+    _logger.exception("Error finding free local port: %r", e)
+    raise
+
+  _logger.info("Found free local port: %d", port_available)
+  _logger.info(
+      "Starting port forwarding from local port %d to %s:%d",
+      port_available,
+      pod_name,
+      server_port,
+  )
+
+  port_forward_command = [
+      "kubectl",
+      "port-forward",
+      "--address",
+      "localhost",
+      pod_name,
+      f"{port_available}:{server_port}",
+  ]
+  try:
+    # Start port forwarding in the background.
+    port_forward_process = subprocess.Popen(
+        port_forward_command,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+  except Exception as e:
+    _logger.exception("Error enabling port forwarding for the pod: %r", e)
+    raise
+
+  # Check that the port forwarding is ready.
+  if port_forward_process.stdout is None:
+    _logger.error("Port-forward process stdout is None. Terminating.")
+    port_forward_process.terminate()
+    _, stderr = port_forward_process.communicate()
+    raise RuntimeError(
+        "Failed to start port forwarding: stdout not available.\n"
+        f"STDERR: {stderr}"
+    )
+
+  ready_line = port_forward_process.stdout.readline()
+  if "Forwarding from" in ready_line:
+    _logger.info("Port-forward is ready: %s", ready_line.strip())
+  else:
+    # If the ready line is not found, the process might have exited with an
+    # error. We terminate it and raise an error with the stderr.
+    _logger.error("Port-forward process exited with error. Terminating.")
+    port_forward_process.terminate()
+    stdout, stderr = port_forward_process.communicate()
+    raise RuntimeError(
+        "Failed to start port forwarding.\n"
+        f"STDOUT: {ready_line + stdout}\n"
+        f"STDERR: {stderr}"
+    )
+
+  try:
+    __test_pod_connection(port_available)
+  except Exception:
+    port_forward_process.terminate()
+    raise
+
+  return (port_available, port_forward_process)
+
+
+def delete_gke_job(job_name: str) -> None:
+  """Deletes the given job from the GKE cluster.
+
+  Args:
+    job_name: The name of the job.
+  """
+  _logger.info("Deleting job: %s", job_name)
+  delete_job_command = [
+      "kubectl",
+      "delete",
+      "job",
+      job_name,
+      "--ignore-not-found",
+  ]
+  try:
+    result = subprocess.run(
+        delete_job_command,
+        check=True,
+        capture_output=True,
+        text=True,
+    )
+  except subprocess.CalledProcessError as e:
+    _logger.exception("Failed to delete job. kubectl output:\n%s", e.stderr)
+    raise
+  _logger.info("Successfully deleted job. %s", result.stdout)
diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
new file mode 100644
index 0000000..3604503
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
@@ -0,0 +1,226 @@
+"""Module for connecting to a Pathways server for interactive supercomputing."""
+
+from collections.abc import Iterator, Mapping
+import contextlib
+import logging
+import os
+import random
+import string
+import subprocess
+from typing import Any
+
+import jax
+import pathwaysutils
+from pathwaysutils.experimental.shared_pathways_service import gke_utils
+from pathwaysutils.experimental.shared_pathways_service import validators
+
+
+PROXY_FILEPATH = os.path.join(
+    os.path.dirname(__file__), "yamls/pw-proxy.yaml"
+)
+# TODO(b/459935429): Hardcoding the port and using hostNetwork: true in the
+# proxy YAML limits us to one proxy server pod per node. Consider alternative
+# networking configurations to allow multiple proxies per node if needed.
+PROXY_SERVER_PORT = 29_000
+
+_JAX_PLATFORMS_KEY = "jax_platforms"
+_JAX_PLATFORM_PROXY = "proxy"
+_JAX_BACKEND_TARGET_KEY = "jax_backend_target"
+_JAX_BACKEND_TARGET_HOSTNAME = "grpc://localhost"
+
+_logger = logging.getLogger(__name__)
+
+
+def _deploy_pathways_proxy_server(
+    *, pathways_service: str,
+    proxy_job_name: str,
+    expected_instances: Mapping[Any, Any],
+    gcs_scratch_location: str,
+) -> None:
+  """Deploys the Pathways proxy pods to the GKE cluster.
+
+  Args:
+    pathways_service: The service name and port of the Pathways head.
+    proxy_job_name: The name to use for the deployed proxy.
+    expected_instances: A dictionary mapping instance types to the number of
+      instances.
+    gcs_scratch_location: The Google Cloud Storage location to use.
+
+  Raises:
+    subprocess.CalledProcessError: If the kubectl command fails.
+  """
+  try:
+    with open(PROXY_FILEPATH, "r") as f:
+      yaml_template = f.read()
+  except OSError as err:
+    raise ValueError(f"Could not read file: {PROXY_FILEPATH}") from err
+
+  pathways_head_hostname, pathways_head_port = pathways_service.split(":")
+
+  # Take the first instance type and count since we only support a single
+  # instance type for now.
+  instance_type, count = next(iter(expected_instances.items()))
+  instances_str = ",".join(instance_type for _ in range(count))
+
+  template = string.Template(yaml_template)
+  substituted_yaml = template.substitute(
+      PROXY_JOB_NAME=proxy_job_name,
+      PROXY_SERVER_PORT=PROXY_SERVER_PORT,
+      PATHWAYS_HEAD_HOSTNAME=pathways_head_hostname,
+      PATHWAYS_HEAD_PORT=pathways_head_port,
+      EXPECTED_INSTANCES=instances_str,
+      GCS_SCRATCH_LOCATION=gcs_scratch_location,
+  )
+
+  _logger.info("Deploying Pathways proxy: %s", proxy_job_name)
+  gke_utils.deploy_gke_yaml(substituted_yaml)
+
+  _logger.info("Successfully deployed Pathways proxy.")
+
+
+class _ISCPathways:
+  """Class for managing TPUs for interactive supercomputing.
+
+  Attributes:
+    cluster: The name of the GKE cluster.
+    project: The GCP project ID.
+    region: The GCP region.
+    bucket: The Google Cloud Storage bucket to use.
+    pathways_service: The service name and port of the Pathways head pod.
+    expected_tpu_instances: A dictionary mapping TPU machine types to the
+      number of instances.
+  """
+
+  def __init__(
+      self,
+      *, cluster: str,
+      project: str,
+      region: str,
+      gcs_bucket: str,
+      pathways_service: str,
+      expected_tpu_instances: Mapping[Any, Any],
+  ):
+    """Initializes the TPU manager."""
+    self.cluster = cluster
+    self.project = project
+    self.region = region
+    self.bucket = gcs_bucket
+    self.pathways_service = pathways_service
+    self.expected_tpu_instances = expected_tpu_instances
+    suffix = "".join(
+        random.choices(string.ascii_lowercase + string.digits, k=5)
+    )
+    user = os.environ.get("USER", "user")
+    self._proxy_job_name = f"isc-proxy-{user}-{suffix}"
+    self._port_forward_process = None
+    self._proxy_port = None
+
+  def __repr__(self):
+    return (
+        f"_ISCPathways(cluster='{self.cluster}', project='{self.project}', "
+        f"region='{self.region}', bucket='{self.bucket}', "
+        f"pathways_service='{self.pathways_service}', "
+        f"expected_tpu_instances={self.expected_tpu_instances}, "
+        f"_proxy_job_name='{self._proxy_job_name}')"
+    )
+
+  def __enter__(self):
+    """Enters the context manager, deploying the proxy and port forwarding."""
+    try:
+      _deploy_pathways_proxy_server(
+          pathways_service=self.pathways_service,
+          proxy_job_name=self._proxy_job_name,
+          expected_instances=self.expected_tpu_instances,
+          gcs_scratch_location=self.bucket,
+      )
+      # Print a link to Cloud Logging.
+      cloud_logging_link = gke_utils.get_log_link(
+          cluster=self.cluster,
+          project=self.project,
+          job_name=self._proxy_job_name,
+      )
+      _logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link)
+
+      proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name)
+      self._proxy_port, self._port_forward_process = (
+          gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT)
+      )
+
+      # Update the JAX backend to use the proxy.
+      jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY)
+      jax.config.update(
+          _JAX_BACKEND_TARGET_KEY,
+          f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}",
+      )
+      pathwaysutils.initialize()
+      _logger.info(
+          "Interactive supercomputing proxy client ready for cluster '%s'.",
+          self.cluster,
+      )
+      return self
+    except Exception as e:
+      _logger.exception("Error setting up Pathways proxy: %r", e)
+      # If any part of setup fails after deployment, clean up.
+      self._cleanup()
+      raise
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    """Exits the context manager."""
+    _logger.info("Exiting ISCPathways context.")
+    self._cleanup()
+
+  def _cleanup(self):
+    """Cleans up resources created by the ISCPathways context."""
+    if self._port_forward_process:
+      self._port_forward_process.terminate()
+      try:
+        self._port_forward_process.wait(timeout=10)
+      except subprocess.TimeoutExpired as e:
+        _logger.exception(
+            "Failed to terminate port forwarding process. Not treating as an "
+            "error: %r",
+            e,
+        )
+
+    _logger.info("Deleting Pathways proxy.")
+    gke_utils.delete_gke_job(self._proxy_job_name)
+
+
+@contextlib.contextmanager
+def connect(
+    *, cluster: str,
+    project: str,
+    region: str,
+    gcs_bucket: str,
+    pathways_service: str,
+    expected_tpu_instances: Mapping[str, int],
+) -> Iterator["_ISCPathways"]:
+  """Connects to a shared Pathways service through a dedicated proxy.
+
+  Fetches credentials for the GKE cluster, deploys a Pathways proxy server
+  job, and forwards a local port to it so that JAX can use the proxy as its
+  backend.
+
+  Args:
+    cluster: The name of the GKE cluster.
+    project: The GCP project ID.
+    region: The GCP region.
+    gcs_bucket: The Google Cloud Storage bucket to use for scratch space.
+    pathways_service: The service name and port of the Pathways head pod.
+    expected_tpu_instances: A dictionary mapping TPU machine types to the
+      number of instances. For example: {"tpuv6e:2x2": 2}
+
+  Yields:
+    The Pathways manager.
+  """
+  validators.validate_pathways_service(pathways_service)
+  validators.validate_tpu_instances(expected_tpu_instances)
+  gke_utils.fetch_cluster_credentials(
+      cluster_name=cluster, project_id=project, location=region
+  )
+  _logger.info("Starting ISCPathways context.")
+  with _ISCPathways(
+      cluster=cluster,
+      project=project,
+      region=region,
+      gcs_bucket=gcs_bucket,
+      pathways_service=pathways_service,
+      expected_tpu_instances=expected_tpu_instances,
+  ) as t:
+    yield t
diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py
new file mode 100644
index 0000000..1835571
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py
@@ -0,0 +1,57 @@
+"""Script to run JAX code on TPU with the Shared Pathways service."""
+
+from collections.abc import Sequence
+import pprint
+
+from absl import app
+from absl import flags
+import jax.numpy as jnp
+from pathwaysutils.experimental.shared_pathways_service import isc_pathways
+
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("cluster", None, "The name of the GKE cluster.")
+flags.DEFINE_string("project", None, "The GCP project ID.")
+flags.DEFINE_string("region", None, "The GCP region.")
+flags.DEFINE_string("gcs_bucket", None, "The Google Cloud Storage bucket.")
+flags.DEFINE_string(
+    "pathways_service",
+    None,
+    "The address and port of the Pathways Resource Manager.",
+)
+flags.DEFINE_string(
+    "tpu_type", "tpuv6e:2x2", "The TPU machine type and topology."
+)
+flags.DEFINE_integer("tpu_count", 1, "The number of TPU instances.")
+
+flags.mark_flags_as_required([
+    "cluster",
+    "project",
+    "region",
+    "gcs_bucket",
+    "pathways_service",
+])
+
+
+def main(argv: Sequence[str]) -> None:
+  if len(argv) > 1:
+    raise app.UsageError("Too many command-line arguments.")
+  with isc_pathways.connect(
+      cluster=FLAGS.cluster,
+      project=FLAGS.project,
+      region=FLAGS.region,
+      gcs_bucket=FLAGS.gcs_bucket,
+      pathways_service=FLAGS.pathways_service,
+      expected_tpu_instances={FLAGS.tpu_type: FLAGS.tpu_count},
+  ):
+    orig_matrix = jnp.zeros(5)
+    result_matrix = orig_matrix + 1
+    print("Original matrix:")
+    pprint.pprint(orig_matrix)
+    print("\nMatrix after adding 1:")
+    pprint.pprint(result_matrix)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/pathwaysutils/experimental/shared_pathways_service/validators.py b/pathwaysutils/experimental/shared_pathways_service/validators.py
new file mode 100644
index 0000000..bd3e7e6
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/validators.py
@@ -0,0 +1,116 @@
+"""Validation functions for Shared Pathways Service."""
+
+from collections.abc import Mapping
+import logging
+import re
+from typing import Any
+
+_logger = logging.getLogger(__name__)
+
+
+def validate_pathways_service(pathways_service: str) -> None:
+  """Validates the Pathways service name and port."""
+  if not pathways_service:
+    raise ValueError("No Pathways service found.")
+  try:
+    pathways_head, pathways_head_port = pathways_service.split(":")
+  except ValueError as e:
+    raise ValueError(
+        f"pathways_service={pathways_service} is not in the expected format"
+        " of `<service-name>:<port>`"
+    ) from e
+  if not pathways_head.strip():
+    raise ValueError(
+        f"pathways_service={pathways_service} contains an empty string for"
+        " the service name. Expected `<service-name>:<port>`"
+    )
+  if not pathways_head_port.strip():
+    raise ValueError(
+        f"pathways_service={pathways_service} contains an empty string for"
+        " the service port. Expected `<service-name>:<port>`"
+    )
+  try:
+    int(pathways_head_port)
+  except ValueError as e:
+    raise ValueError(
+        f"pathways_service={pathways_service} contains a non-numeric service"
+        " port. Expected `<service-name>:<port>`"
+    ) from e
+
+
+def _validate_tpu_supported(tpu_instance_with_topology: str) -> None:
+  """Checks if the given instance represents a valid single-host TPU.
+
+  Args:
+    tpu_instance_with_topology: The TPU instance string, e.g., "tpuv6e:2x2".
+
+  Raises:
+    ValueError: If the instance is not a valid single-host TPU.
+  """
+  # Mapping from Cloud TPU type prefix to max chips per host.
+  single_host_max_chips = {
+      "tpuv6e": 8,  # Cloud TPU v6e (2x4)
+  }
+
+  # Regex to extract the TPU type and topology.
+  # Examples:
+  #   tpuv6e:2x2 -> tpuv6e, 2x2
+  #   tpuv6e:2x2x1 -> tpuv6e, 2x2x1
+  match = re.match(
+      r"^(?P<type>tpuv6e):(?P<topology>\d+(?:x\d+)*)$",
+      tpu_instance_with_topology,
+  )
+
+  if match:
+    tpu_base_type = match.group("type")
+    topology_str = match.group("topology")
+
+    if tpu_base_type not in single_host_max_chips:
+      raise ValueError(
+          f"Unknown TPU type '{tpu_base_type}' from"
+          f" '{tpu_instance_with_topology}'."
+      )
+
+    try:
+      dims = [int(d) for d in topology_str.split("x")]
+      if len(dims) < 2 or len(dims) > 3:
+        raise ValueError(
+            f"Error: Invalid topology format '{topology_str}', Expected"
+            " either 2 or 3 dimensions."
+        )
+      num_chips = 1
+      for dim in dims:
+        num_chips *= dim
+    except ValueError as exc:
+      raise ValueError(
+          f"Error: Invalid topology format '{topology_str}' in"
+          f" '{tpu_instance_with_topology}'."
+      ) from exc
+
+    if num_chips > single_host_max_chips[tpu_base_type]:
+      raise ValueError(
+          f"Topology '{tpu_instance_with_topology}' exceeds"
+          f" {single_host_max_chips[tpu_base_type]}, the maximum supported"
+          f" chips for {tpu_base_type}."
+      )
+
+    return
+
+  raise ValueError(
+      f"Unrecognized instance format: {tpu_instance_with_topology}."
+  )
+
+
+def validate_tpu_instances(expected_tpu_instances: Mapping[Any, Any]) -> None:
+  """Validates the instance list."""
+  if not expected_tpu_instances:
+    raise ValueError("No instances found.")
+  for inst in expected_tpu_instances.keys():
+    if not inst.strip():
+      raise ValueError(
+          f"expected_tpu_instances={expected_tpu_instances} contains an"
+          " empty string for an instance name."
+      )
+  if len(expected_tpu_instances.keys()) != 1:
+    raise ValueError("Only one machine type is supported at this time.")
+
+  inst = next(iter(expected_tpu_instances.keys()))
+  _validate_tpu_supported(inst)
diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml
new file mode 100644
index 0000000..38bc524
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml
@@ -0,0 +1,43 @@
+apiVersion: batch/v1
+kind: Job
+metadata:
+  name: ${PROXY_JOB_NAME}
+spec:
+  backoffLimit: 0
+  completions: 1
+  parallelism: 1
+  template:
+    metadata:
+      labels:
+        app: pathways-proxy
+    spec:
+      automountServiceAccountToken: false
+      containers:
+      - name: pathways-proxy
+        image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
+        imagePullPolicy: Always
+        args:
+        - --server_port=${PROXY_SERVER_PORT}
+        - --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT}
+        - --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
+        - --virtual_slices=${EXPECTED_INSTANCES}
+        ports:
+        - containerPort: ${PROXY_SERVER_PORT}
+          protocol: TCP
+        resources:
+          limits:
+            cpu: "16"
+            memory: 100G
+        securityContext:
+          runAsUser: 1000  # go/gke-shipshape#rootless
+          runAsGroup: 1000  # go/gke-shipshape#rootless
+          readOnlyRootFilesystem: true  # go/gke-shipshape#readonlyrootfs
+          capabilities:  # go/gke-shipshape#capabilities
+            drop:
+            - ALL
+          seccompProfile:  # go/gke-shipshape#seccomp
+            type: RuntimeDefault
+          allowPrivilegeEscalation: false  # go/gke-shipshape#allowprivilegeescalation
+      dnsPolicy: ClusterFirstWithHostNet
+      hostNetwork: true
+      restartPolicy: OnFailure
diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service-example.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service-example.yaml
new file mode 100644
index 0000000..832418c
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service-example.yaml
@@ -0,0 +1,13 @@
+apiVersion: pathways-job.pathways.domain/v1
+kind: PathwaysJob
+metadata:
+  name: pathways-cluster  # jobset name
+spec:
+  maxRestarts: 1
+  workers:  # Modify this section to use your TPU type, topology, number of slices and the GCS bucket.
+  - type: ct6e-standard-4t
+    topology: 2x2
+    numSlices: 2
+  pathwaysDir: "gs://pathways-bucket"  # Pre-create this bucket.
+  controller:
+    deploymentMode: default