Skip to content

Muyang yu/lwx service copy #1330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion axlearn/cloud/gcp/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
See also ``On configuration`` in `axlearn/cloud/gcp/job.py`.
"""

import enum
import logging
import shlex
import subprocess
Expand All @@ -19,6 +20,7 @@
from axlearn.cloud.common.utils import generate_job_name, subprocess_run
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone
from axlearn.cloud.gcp.jobset_utils import BaseReplicatedJob
from axlearn.cloud.gcp.k8s_service import LWSService
from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate
from axlearn.cloud.gcp.utils import (
custom_jobset_kwargs,
Expand All @@ -30,6 +32,25 @@
from axlearn.common.utils import Nested


class _ServiceProtocol(enum.Enum):

"""https://kubernetes.io/docs/reference/networking/service-protocols/"""

TCP = "TCP"
UDP = "UDP"
SCTP = "SCTP"


class _ServiceType(enum.Enum):

"""https://cloud.google.com/kubernetes-engine/docs/concepts/service#types-of-services sss"""

CLUSTERIP = "ClusterIP"
NODEPORT = "NodePort"
LOADBALANCER = "LoadBalancer"
EXTERNALNAME = "ExternalName"


class GCPJob(Job):
"""Base GCP Job definition."""

Expand Down Expand Up @@ -292,23 +313,73 @@ class Config(GCPJob.Config):
namespace: str = "default"
annotations: Optional[ConfigOr[dict]] = None
num_replicas: int = 1
enable_service: bool = False
port: int = None
targetport: int = None
service_type: str = None
protocol: str = None
service: Optional[LWSService.Config] = None

@classmethod
def set_defaults(cls, fv):
super().set_defaults(fv)
fv.set_default("max_tries", fv.max_tries or 10)
fv.set_default("retry_interval", fv.retry_interval or 60)
fv.set_default("enable_service", fv.enable_service or False)
fv.set_default("targetport", fv.targetport or 9000)
fv.set_default("port", fv.port or 9000)
fv.set_default("protocol", fv.protocol or _ServiceProtocol.TCP.value)
fv.set_default("service_type", fv.service_type or _ServiceType.CLUSTERIP.value)

@classmethod
def define_flags(cls, fv: flags.FlagValues):
super().define_flags(fv)
common_kwargs = dict(flag_values=fv, allow_override=True)
flags.DEFINE_string("name", None, "Name of the LeaderWorkerSet.", **common_kwargs)
flags.DEFINE_boolean(
"enable_service",
False,
"Whether to enable creation of service for LWS",
**common_kwargs,
)
#### https://kubernetes.io/docs/reference/networking/service-protocols/ #####
#### Available types: TCP, UDP, SCTP #####
flags.DEFINE_enum(
"protocol",
None,
[v.value for v in _ServiceProtocol],
help="Protocol type of service for LWS",
flag_values=fv,
)
##### https://cloud.google.com/kubernetes-engine/docs/how-to/exposing-apps ####
## Available types: ClusterIP(default), NodePort, LoadBalancer, ExternalName, Headless ##
flags.DEFINE_enum(
"service_type",
None,
[v.value for v in _ServiceType],
help="Service type for LWS",
flag_values=fv,
)
flags.DEFINE_integer(
"port",
None,
"External port where application is exposed through service",
**common_kwargs,
)

flags.DEFINE_integer(
"targetport", None, " Application port which the service redirects to", **common_kwargs
)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs):
cfg: GKELeaderWorkerSet.Config = super().from_flags(fv, **kwargs)
cfg.num_replicas = fv.num_replicas
cfg.enable_service = fv.enable_service
cfg.port = fv.port
cfg.targetport = fv.targetport
cfg.protocol = fv.protocol
cfg.service_type = fv.service_type
return cfg

def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
Expand Down Expand Up @@ -356,11 +427,19 @@ def _execute(self):
**self._build_leaderworkerset(),
)
logging.info("submitting LeaderWorkerSet: %s", custom_object)
return k8s.client.CustomObjectsApi().create_namespaced_custom_object(
lws_resp = k8s.client.CustomObjectsApi().create_namespaced_custom_object(
namespace=cfg.namespace,
body=custom_object,
**api_kwargs,
)
#### Creating a Service #######
if cfg.enable_service:
service_resp = cfg.service.instantiate().execute()
logging.info("Service created %s", str(service_resp))
else:
cfg.service = None

return lws_resp


def exclusive_topology_annotations_leaderworkerset() -> dict:
Expand Down
162 changes: 162 additions & 0 deletions axlearn/cloud/gcp/k8s_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
""" k8s service module."""
import copy
import logging
from typing import Any, Optional

import kubernetes as k8s
from absl import flags

from axlearn.cloud.common.utils import FlagConfigurable, generate_job_name
from axlearn.cloud.gcp.config import default_project
from axlearn.cloud.gcp.utils import custom_leaderworkerset_kwargs
from axlearn.common.config import REQUIRED, Required, config_class
from axlearn.common.utils import Nested


class Service(FlagConfigurable):
"""Service interface"""

@config_class
class Config(FlagConfigurable.Config):
"""Configures Service
Attributes:
name: The name of LWS resource.
project: The poject to use within the k8s cluster.
"""

name: Required[str] = REQUIRED
project: Required[str] = REQUIRED

@classmethod
def define_flags(cls, fv: flags.FlagValues):
common_kwargs = dict(flag_values=fv, allow_override=True)
flags.DEFINE_string("name", None, "Name of the service.", **common_kwargs)
flags.DEFINE_string("project", None, "The GCP project name.", **common_kwargs)

@classmethod
def set_defaults(cls, fv: flags.FlagValues):
fv.set_default("name", fv.name or generate_job_name())
fv.set_default("project", default_project())

def _delete(self):
"""Cleans up the service. Called on termination when all retries are exhausted.

Note that `_delete` is not called when `_execute` finishes successfully. It is up
to the implementation of `_execute` to clean up properly.
"""

def _build_service(self) -> Any:
"""Performs some computation. The return value can be implementation dependent."""
raise NotImplementedError(type(self))


class LWSService(Service):
"""LWS Service"""

@config_class
class Config(Service.Config):
"""Configures Service
Attributes:
namespace: The namespace to use within the k8s cluster.
protocol: protocol for service , ex: TCP, HTTP
port: the exposed port of service
targetport: the application port of leader pod
service_type: Type of Service , ex: ClusterIP
"""

namespace: str = None
protocol: Optional[str] = None
port: Optional[int] = None
targetport: Optional[int] = None
service_type: Optional[str] = None

@classmethod
def define_flags(cls, fv: flags.FlagValues):
super().define_flags(fv)
common_kwargs = dict(flag_values=fv, allow_override=True)
flags.DEFINE_string("name", None, "Name of the service.", **common_kwargs)
flags.DEFINE_string("namespace", None, "Namespace of the service.", **common_kwargs)
flags.DEFINE_string("protocol", None, "Protocol of the service.", **common_kwargs)
flags.DEFINE_string("service_type", None, "Type of the service.", **common_kwargs)
flags.DEFINE_integer("port", None, "Port of the service.", **common_kwargs)
flags.DEFINE_integer("targetport", None, "TargetPort of the service.", **common_kwargs)

@classmethod
def set_defaults(cls, fv: flags.FlagValues):
super().set_defaults(fv)
fv.set_default("namespace", fv.namespace or "default")
fv.set_default("protocol", fv.protocol or "TCP")
fv.set_default("service_type", fv.service_type or "ClusterIP")
fv.set_default("port", fv.port or 9000)
fv.set_default("targetport", fv.targetport or 9000)

@classmethod
def default_config(cls):
return super().default_config()

def __init__(self, cfg: Config):
super().__init__(cfg)
logging.info("LWSService class init")
self._config = copy.deepcopy(cfg)
self.name = cfg.name + "-service"
self.protocol = cfg.protocol
self.port = cfg.port
self.targetport = cfg.targetport
self.service_type = cfg.service_type
self.label_name = cfg.name

def _build_service(self) -> Nested[Any]:
"""
Builds a config for a Service
Returns:
A nested dict corresponding to a k8s Service config
"""
logging.info("LWSservice class build")
logging.info(str(self.config))
api_kwargs = custom_leaderworkerset_kwargs()

namespace = "default"
group = api_kwargs["group"]
version = api_kwargs["version"]
plural = api_kwargs["plural"]
lws_name = self.name.split("-service")[0]
custom_api = k8s.client.CustomObjectsApi()

# Fetch the CR object
lws = custom_api.get_namespaced_custom_object(
group=group, version=version, namespace=namespace, plural=plural, name=lws_name
)

return dict(
metadata=k8s.client.V1ObjectMeta(
name=self.name,
owner_references=[
k8s.client.V1OwnerReference(
api_version=f"{api_kwargs['group']}/{api_kwargs['version']}",
kind="LeaderWorkerSet",
name=lws_name, ### self.name is a name+"-service"
uid=lws["metadata"]["uid"],
controller=True,
block_owner_deletion=True,
)
],
),
spec=k8s.client.V1ServiceSpec(
selector={"app": self.label_name},
ports=[
k8s.client.V1ServicePort(
protocol=self.protocol,
port=self.port,
target_port=self.targetport,
)
],
type=self.service_type,
),
)

def execute(self):
logging.info("LWSservice class execute")
service = self._build_service()
logging.info("Submitting LWSservice body=%s ", service)
v1 = k8s.client.CoreV1Api()
return v1.create_namespaced_service(namespace=self.config.namespace, body=service)
51 changes: 51 additions & 0 deletions axlearn/cloud/gcp/k8s_service_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests k8s service module."""

from absl import flags

from axlearn.cloud.common.utils import define_flags, from_flags
from axlearn.cloud.gcp import k8s_service
from axlearn.common.test_utils import TestCase

FLAGS = flags.FLAGS


class GKELWSService(TestCase):
"""Tests GKEService with LWS(TPU)."""

def _service_config(
self,
*,
command: str,
**kwargs,
) -> k8s_service.LWSService.Config:
fv = flags.FlagValues()
cfg = k8s_service.LWSService.default_config().set()

define_flags(cfg, fv)
print(kwargs)
for key, value in kwargs.items():
print(key, value)
if value is not None:
# Use setattr rather than set_default to set flags.
setattr(fv, key, value)
fv.name = "fake-name"
fv.project = "fake-project"
fv.mark_as_parsed()
cfg = from_flags(cfg, fv, command=command)
# Test that retries are configured on fv by default.
self.assertIsNotNone(fv["name"])
return cfg

def test_instantiate(
self,
):
cfg = self._service_config(
command="test-command",
project="fake-project",
port=9000,
)
self.assertIsInstance(cfg, k8s_service.LWSService.Config)
self.assertEqual(cfg.project, "fake-project")
gke_lws_service = cfg.set().instantiate()
self.assertEqual(cfg.name + "-service", gke_lws_service.name)
self.assertEqual(cfg.port, gke_lws_service.port)
Loading
Loading