diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index b27c664e9..23952ab3e 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -5,6 +5,7 @@ See also ``On configuration`` in `axlearn/cloud/gcp/job.py`. """ +import enum import logging import shlex import subprocess @@ -19,6 +20,7 @@ from axlearn.cloud.common.utils import generate_job_name, subprocess_run from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone from axlearn.cloud.gcp.jobset_utils import BaseReplicatedJob +from axlearn.cloud.gcp.k8s_service import LWSService from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate from axlearn.cloud.gcp.utils import ( custom_jobset_kwargs, @@ -30,6 +32,25 @@ from axlearn.common.utils import Nested +class _ServiceProtocol(enum.Enum): + + """https://kubernetes.io/docs/reference/networking/service-protocols/""" + + TCP = "TCP" + UDP = "UDP" + SCTP = "SCTP" + + +class _ServiceType(enum.Enum): + + """https://cloud.google.com/kubernetes-engine/docs/concepts/service#types-of-services sss""" + + CLUSTERIP = "ClusterIP" + NODEPORT = "NodePort" + LOADBALANCER = "LoadBalancer" + EXTERNALNAME = "ExternalName" + + class GCPJob(Job): """Base GCP Job definition.""" @@ -292,23 +313,73 @@ class Config(GCPJob.Config): namespace: str = "default" annotations: Optional[ConfigOr[dict]] = None num_replicas: int = 1 + enable_service: bool = False + port: int = None + targetport: int = None + service_type: str = None + protocol: str = None + service: Optional[LWSService.Config] = None @classmethod def set_defaults(cls, fv): super().set_defaults(fv) fv.set_default("max_tries", fv.max_tries or 10) fv.set_default("retry_interval", fv.retry_interval or 60) + fv.set_default("enable_service", fv.enable_service or False) + fv.set_default("targetport", fv.targetport or 9000) + fv.set_default("port", fv.port or 9000) + fv.set_default("protocol", fv.protocol or _ServiceProtocol.TCP.value) + fv.set_default("service_type", fv.service_type or _ServiceType.CLUSTERIP.value) @classmethod def define_flags(cls, fv: flags.FlagValues): super().define_flags(fv) common_kwargs = dict(flag_values=fv, allow_override=True) flags.DEFINE_string("name", None, "Name of the LeaderWorkerSet.", **common_kwargs) + flags.DEFINE_boolean( + "enable_service", + False, + "Whether to enable creation of service for LWS", + **common_kwargs, + ) + #### https://kubernetes.io/docs/reference/networking/service-protocols/ ##### + #### Available types: TCP, UDP, SCTP ##### + flags.DEFINE_enum( + "protocol", + None, + [v.value for v in _ServiceProtocol], + help="Protocol type of service for LWS", + flag_values=fv, + ) + ##### https://cloud.google.com/kubernetes-engine/docs/how-to/exposing-apps #### + ## Available types: ClusterIP(default), NodePort, LoadBalancer, ExternalName, Headless ## + flags.DEFINE_enum( + "service_type", + None, + [v.value for v in _ServiceType], + help="Service type for LWS", + flag_values=fv, + ) + flags.DEFINE_integer( + "port", + None, + "External port where application is exposed through service", + **common_kwargs, + ) + + flags.DEFINE_integer( + "targetport", None, " Application port which the service redirects to", **common_kwargs + ) @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs): cfg: GKELeaderWorkerSet.Config = super().from_flags(fv, **kwargs) cfg.num_replicas = fv.num_replicas + cfg.enable_service = fv.enable_service + cfg.port = fv.port + cfg.targetport = fv.targetport + cfg.protocol = fv.protocol + cfg.service_type = fv.service_type return cfg def __init__(self, cfg: Config, *, bundler: BaseDockerBundler): @@ -356,11 +427,19 @@ def _execute(self): **self._build_leaderworkerset(), ) logging.info("submitting LeaderWorkerSet: %s", custom_object) - return k8s.client.CustomObjectsApi().create_namespaced_custom_object( + lws_resp = k8s.client.CustomObjectsApi().create_namespaced_custom_object( namespace=cfg.namespace, body=custom_object, **api_kwargs, ) + #### Creating a Service ####### + if cfg.enable_service: + service_resp = cfg.service.instantiate().execute() + logging.info("Service created %s", str(service_resp)) + else: + cfg.service = None + + return lws_resp def exclusive_topology_annotations_leaderworkerset() -> dict: diff --git a/axlearn/cloud/gcp/k8s_service.py b/axlearn/cloud/gcp/k8s_service.py new file mode 100644 index 000000000..954b14f40 --- /dev/null +++ b/axlearn/cloud/gcp/k8s_service.py @@ -0,0 +1,162 @@ +""" k8s service module.""" +import copy +import logging +from typing import Any, Optional + +import kubernetes as k8s +from absl import flags + +from axlearn.cloud.common.utils import FlagConfigurable, generate_job_name +from axlearn.cloud.gcp.config import default_project +from axlearn.cloud.gcp.utils import custom_leaderworkerset_kwargs +from axlearn.common.config import REQUIRED, Required, config_class +from axlearn.common.utils import Nested + + +class Service(FlagConfigurable): + """Service interface""" + + @config_class + class Config(FlagConfigurable.Config): + """Configures Service + Attributes: + name: The name of LWS resource. + project: The poject to use within the k8s cluster. + """ + + name: Required[str] = REQUIRED + project: Required[str] = REQUIRED + + @classmethod + def define_flags(cls, fv: flags.FlagValues): + common_kwargs = dict(flag_values=fv, allow_override=True) + flags.DEFINE_string("name", None, "Name of the service.", **common_kwargs) + flags.DEFINE_string("project", None, "The GCP project name.", **common_kwargs) + + @classmethod + def set_defaults(cls, fv: flags.FlagValues): + fv.set_default("name", fv.name or generate_job_name()) + fv.set_default("project", default_project()) + + def _delete(self): + """Cleans up the service. Called on termination when all retries are exhausted. + + Note that `_delete` is not called when `_execute` finishes successfully. It is up + to the implementation of `_execute` to clean up properly. + """ + + def _build_service(self) -> Any: + """Performs some computation. The return value can be implementation dependent.""" + raise NotImplementedError(type(self)) + + +class LWSService(Service): + """LWS Service""" + + @config_class + class Config(Service.Config): + """Configures Service + Attributes: + namespace: The namespace to use within the k8s cluster. + protocol: protocol for service , ex: TCP, HTTP + port: the exposed port of service + targetport: the application port of leader pod + service_type: Type of Service , ex: ClusterIP + """ + + namespace: str = None + protocol: Optional[str] = None + port: Optional[int] = None + targetport: Optional[int] = None + service_type: Optional[str] = None + + @classmethod + def define_flags(cls, fv: flags.FlagValues): + super().define_flags(fv) + common_kwargs = dict(flag_values=fv, allow_override=True) + flags.DEFINE_string("name", None, "Name of the service.", **common_kwargs) + flags.DEFINE_string("namespace", None, "Namespace of the service.", **common_kwargs) + flags.DEFINE_string("protocol", None, "Protocol of the service.", **common_kwargs) + flags.DEFINE_string("service_type", None, "Type of the service.", **common_kwargs) + flags.DEFINE_integer("port", None, "Port of the service.", **common_kwargs) + flags.DEFINE_integer("targetport", None, "TargetPort of the service.", **common_kwargs) + + @classmethod + def set_defaults(cls, fv: flags.FlagValues): + super().set_defaults(fv) + fv.set_default("namespace", fv.namespace or "default") + fv.set_default("protocol", fv.protocol or "TCP") + fv.set_default("service_type", fv.service_type or "ClusterIP") + fv.set_default("port", fv.port or 9000) + fv.set_default("targetport", fv.targetport or 9000) + + @classmethod + def default_config(cls): + return super().default_config() + + def __init__(self, cfg: Config): + super().__init__(cfg) + logging.info("LWSService class init") + self._config = copy.deepcopy(cfg) + self.name = cfg.name + "-service" + self.protocol = cfg.protocol + self.port = cfg.port + self.targetport = cfg.targetport + self.service_type = cfg.service_type + self.label_name = cfg.name + + def _build_service(self) -> Nested[Any]: + """ + Builds a config for a Service + Returns: + A nested dict corresponding to a k8s Service config + """ + logging.info("LWSservice class build") + logging.info(str(self.config)) + api_kwargs = custom_leaderworkerset_kwargs() + + namespace = "default" + group = api_kwargs["group"] + version = api_kwargs["version"] + plural = api_kwargs["plural"] + lws_name = self.name.split("-service")[0] + custom_api = k8s.client.CustomObjectsApi() + + # Fetch the CR object + lws = custom_api.get_namespaced_custom_object( + group=group, version=version, namespace=namespace, plural=plural, name=lws_name + ) + + return dict( + metadata=k8s.client.V1ObjectMeta( + name=self.name, + owner_references=[ + k8s.client.V1OwnerReference( + api_version=f"{api_kwargs['group']}/{api_kwargs['version']}", + kind="LeaderWorkerSet", + name=lws_name, ### self.name is a name+"-service" + uid=lws["metadata"]["uid"], + controller=True, + block_owner_deletion=True, + ) + ], + ), + spec=k8s.client.V1ServiceSpec( + selector={"app": self.label_name}, + ports=[ + k8s.client.V1ServicePort( + protocol=self.protocol, + port=self.port, + target_port=self.targetport, + ) + ], + type=self.service_type, + ), + ) + + def execute(self): + logging.info("LWSservice class execute") + service = self._build_service() + logging.info("Submitting LWSservice body=%s ", service) + v1 = k8s.client.CoreV1Api() + return v1.create_namespaced_service(namespace=self.config.namespace, body=service) diff --git a/axlearn/cloud/gcp/k8s_service_test.py b/axlearn/cloud/gcp/k8s_service_test.py new file mode 100644 index 000000000..c52dd493e --- /dev/null +++ b/axlearn/cloud/gcp/k8s_service_test.py @@ -0,0 +1,51 @@ +"""Tests k8s service module.""" + +from absl import flags + +from axlearn.cloud.common.utils import define_flags, from_flags +from axlearn.cloud.gcp import k8s_service +from axlearn.common.test_utils import TestCase + +FLAGS = flags.FLAGS + + +class GKELWSService(TestCase): + """Tests GKEService with LWS(TPU).""" + + def _service_config( + self, + *, + command: str, + **kwargs, + ) -> k8s_service.LWSService.Config: + fv = flags.FlagValues() + cfg = k8s_service.LWSService.default_config().set() + + define_flags(cfg, fv) + print(kwargs) + for key, value in kwargs.items(): + print(key, value) + if value is not None: + # Use setattr rather than set_default to set flags. + setattr(fv, key, value) + fv.name = "fake-name" + fv.project = "fake-project" + fv.mark_as_parsed() + cfg = from_flags(cfg, fv, command=command) + # Test that retries are configured on fv by default. + self.assertIsNotNone(fv["name"]) + return cfg + + def test_instantiate( + self, + ): + cfg = self._service_config( + command="test-command", + project="fake-project", + port=9000, + ) + self.assertIsInstance(cfg, k8s_service.LWSService.Config) + self.assertEqual(cfg.project, "fake-project") + gke_lws_service = cfg.set().instantiate() + self.assertEqual(cfg.name + "-service", gke_lws_service.name) + self.assertEqual(cfg.port, gke_lws_service.port) diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 6824465ab..f8285ed1d 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -732,6 +732,8 @@ class Config(BaseLeaderWorkerTemplate.Config): pathways_xla_flags: list[str] = [] pathways_head_cpu: Optional[str] = None pathways_head_mem: Optional[str] = None + targetport: Optional[int] = None + enable_service: bool = None @classmethod def define_flags(cls, fv): @@ -760,12 +762,26 @@ def define_flags(cls, fv): "Memory request for pathways-head container in GiB. Default is 16GiB", **common_kwargs, ) + flags.DEFINE_boolean( + "enable_service", + False, + "Whether to enable creation of service for LWS", + **common_kwargs, + ) + flags.DEFINE_integer( + "targetport", + None, + "port where a service can access application, set at head container", + **common_kwargs, + ) @classmethod def set_defaults(cls, fv): super().set_defaults(fv) fv.set_default("pathways_head_cpu", fv.pathways_head_cpu or "1") fv.set_default("pathways_head_mem", fv.pathways_head_mem or "16") + fv.set_default("targetport", fv.targetport or 9000) + fv.set_default("enable_service", fv.enable_service or False) @classmethod def default_config(cls): @@ -907,6 +923,9 @@ def _build_head_container(self) -> dict: ], imagePullPolicy="Always", resources=resources, + ports=[dict(containerPort=self.config.targetport)] + if self.config.enable_service + else [], ) def build_leader_pod(self) -> Nested[Any]: @@ -919,6 +938,7 @@ def build_leader_pod(self) -> Nested[Any]: labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)}) volumes.append(dict(name="shared-output", emptyDir={})) + labels = {"app": cfg.name} if cfg.gcsfuse_mount: annotations.update( diff --git a/axlearn/cloud/gcp/runners/__init__.py b/axlearn/cloud/gcp/runners/__init__.py index 8c4cb35af..45e224372 100644 --- a/axlearn/cloud/gcp/runners/__init__.py +++ b/axlearn/cloud/gcp/runners/__init__.py @@ -19,6 +19,7 @@ A4HighReplicatedJob, TPUReplicatedJob, ) +from axlearn.cloud.gcp.k8s_service import LWSService from axlearn.cloud.gcp.node_pool_provisioner import TPUNodePoolProvisioner from axlearn.cloud.gcp.pathways_utils import ( PathwaysLeaderWorkerTemplate, @@ -63,6 +64,7 @@ def named_runner_configs( inner=GKELeaderWorkerSet.default_config().set( builder=PathwaysLeaderWorkerTemplate.default_config(), annotations=config_for_function(exclusive_topology_annotations_leaderworkerset), + service=LWSService.default_config(), ), pre_provisioner=TPUNodePoolProvisioner.default_config(), ),