Skip to content

Commit d4c77b5

Browse files
Oliver Hsuchanglan
authored and committed
GKE Gateway to service express route to minimize the latency
GitOrigin-RevId: 2bcb8f0c445e67e6285281bef4cad7f90099bf0a
1 parent b4d13c7 commit d4c77b5

File tree

10 files changed

+1254
-10
lines changed

10 files changed

+1254
-10
lines changed

axlearn/cloud/gcp/job.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from axlearn.cloud.common.utils import generate_job_name, subprocess_run
2424
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone
2525
from axlearn.cloud.gcp.jobset_utils import BaseReplicatedJob
26+
from axlearn.cloud.gcp.k8s_health_check_policy import LWSHealthCheckPolicy
27+
from axlearn.cloud.gcp.k8s_http_route import LWSHTTPRoute
2628
from axlearn.cloud.gcp.k8s_service import LWSService
2729
from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate
2830
from axlearn.cloud.gcp.utils import (
@@ -437,14 +439,22 @@ class Config(GCPJob.Config):
437439
protocol_list: list[str] = None
438440
port_names: list[str] = None
439441
service: Optional[LWSService.Config] = None
442+
gke_gateway_route: bool = False
443+
http_route: Optional[LWSHTTPRoute.Config] = None
444+
health_check_policy: Optional[LWSHealthCheckPolicy.Config] = None
440445

441446
@classmethod
442447
def set_defaults(cls, fv):
443448
super().set_defaults(fv)
444449
fv.set_default("max_tries", fv.max_tries or 10)
445450
fv.set_default("retry_interval", fv.retry_interval or 60)
446451

447-
fv.set_default("enable_service", fv.enable_service or False)
452+
fv.set_default("gke_gateway_route", fv.gke_gateway_route or False)
453+
# When gke_gateway_route is set, enable_service is implicitly True
454+
if fv.gke_gateway_route:
455+
fv.set_default("enable_service", True)
456+
else:
457+
fv.set_default("enable_service", fv.enable_service or False)
448458
fv.set_default("targetports", fv.targetports or ["9090"])
449459
fv.set_default("ports", fv.ports or ["9090"])
450460
fv.set_default("protocol_list", fv.protocol_list or [_ServiceProtocol.TCP.value])
@@ -498,6 +508,12 @@ def define_flags(cls, fv: flags.FlagValues):
498508
"Protocol list needed for different port and targetport combinations",
499509
**common_kwargs,
500510
)
511+
flags.DEFINE_boolean(
512+
"gke_gateway_route",
513+
False,
514+
"Enable gke_gateway_route with notary-proxy sidecars for direct gateway routing",
515+
**common_kwargs,
516+
)
501517

502518
@classmethod
503519
def from_flags(cls, fv: flags.FlagValues, **kwargs):
@@ -509,6 +525,7 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs):
509525
cfg.protocol_list = fv.protocol_list
510526
cfg.port_names = fv.port_names
511527
cfg.service_type = fv.service_type
528+
cfg.gke_gateway_route = fv.gke_gateway_route
512529
return cfg
513530

514531
def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
@@ -521,6 +538,14 @@ def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
521538
# required to run the job.
522539
self._builder: BaseLeaderWorkerTemplate = cfg.builder.instantiate(bundler=bundler)
523540

541+
# Wire gke_gateway_route flag to service and http_route configs
542+
if cfg.service is not None:
543+
cfg.service.set(name=cfg.name, gke_gateway_route=cfg.gke_gateway_route)
544+
if cfg.http_route is not None:
545+
cfg.http_route.set(name=cfg.name, namespace=cfg.namespace)
546+
if cfg.health_check_policy is not None:
547+
cfg.health_check_policy.set(name=cfg.name, namespace=cfg.namespace)
548+
524549
def _delete(self):
525550
cfg: GKELeaderWorkerSet.Config = self.config
526551
# Issues a delete request for the LeaderWorkerSet and proactively delete its descendants.
@@ -570,6 +595,16 @@ def _execute(self):
570595
else:
571596
cfg.service = None
572597

598+
#### Creating HTTPRoute for gke_gateway_route #######
599+
if cfg.gke_gateway_route and cfg.http_route:
600+
http_route_resp = cfg.http_route.instantiate().execute()
601+
logging.info("HTTPRoute created %s", str(http_route_resp))
602+
603+
#### Creating HealthCheckPolicy for gke_gateway_route #######
604+
if cfg.gke_gateway_route and cfg.health_check_policy:
605+
health_check_resp = cfg.health_check_policy.instantiate().execute()
606+
logging.info("HealthCheckPolicy created %s", str(health_check_resp))
607+
573608
return lws_resp
574609

575610

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
"""k8s HealthCheckPolicy module for gke_gateway_route feature."""
4+
5+
import copy
6+
import logging
7+
from typing import Any, Optional
8+
9+
import kubernetes as k8s
10+
from absl import flags
11+
12+
from axlearn.cloud.common.utils import FlagConfigurable, generate_job_name
13+
from axlearn.cloud.gcp.config import default_project
14+
from axlearn.cloud.gcp.pathways_utils import NOTARY_PROXY_HTTP_PORT
15+
from axlearn.common.config import REQUIRED, Required, config_class
16+
from axlearn.common.utils import Nested
17+
18+
19+
class LWSHealthCheckPolicy(FlagConfigurable):
    """Creates a GKE `HealthCheckPolicy` object for the gke_gateway_route feature.

    The policy configures a TCP health check that targets the Service named `<name>-service`
    and carries an owner reference to the LeaderWorkerSet named `<name>`.
    """

    @config_class
    class Config(FlagConfigurable.Config):
        """Configures LWSHealthCheckPolicy.

        Attributes:
            name: Name of the LWS resource. Also used to derive the policy name
                (`<name>-health-check`) and the targeted Service name (`<name>-service`).
            project: The GCP project.
            namespace: Namespace in which the LWS is looked up and the policy is created.
            check_interval_sec: Seconds between consecutive health checks.
            timeout_sec: Seconds before a single health check times out.
            healthy_threshold: Consecutive successes required to mark healthy.
            unhealthy_threshold: Consecutive failures required to mark unhealthy.
            health_check_port: Port for the TCP health check. If unset (or falsy),
                NOTARY_PROXY_HTTP_PORT is used.
        """

        name: Required[str] = REQUIRED
        project: Required[str] = REQUIRED
        namespace: str = "default"
        check_interval_sec: int = 10
        timeout_sec: int = 5
        healthy_threshold: int = 1
        unhealthy_threshold: int = 3
        health_check_port: Optional[int] = None

    @classmethod
    def define_flags(cls, fv: flags.FlagValues):
        """Defines the flags consumed by `set_defaults` and `from_flags` on `fv`."""
        shared = dict(flag_values=fv, allow_override=True)
        flags.DEFINE_string("name", None, "Name of the HealthCheckPolicy.", **shared)
        flags.DEFINE_string("project", None, "The GCP project name.", **shared)

        # (flag name, default, help) for every integer flag, in definition order.
        integer_flags = (
            ("health_check_interval_sec", 10, "Interval between health checks in seconds."),
            ("health_check_timeout_sec", 5, "Timeout for each health check in seconds."),
            (
                "health_check_healthy_threshold",
                1,
                "Number of consecutive successes to mark healthy.",
            ),
            (
                "health_check_unhealthy_threshold",
                3,
                "Number of consecutive failures to mark unhealthy.",
            ),
            (
                "health_check_port",
                None,
                "Port to use for TCP health check. Defaults to NOTARY_PROXY_HTTP_PORT.",
            ),
        )
        for flag_name, default, help_text in integer_flags:
            flags.DEFINE_integer(flag_name, default, help_text, **shared)

    @classmethod
    def set_defaults(cls, fv: flags.FlagValues):
        """Fills in flag defaults: a generated name, the default project and probe timings."""
        # Only generate a job name when none has been provided.
        job_name = fv.name or generate_job_name()
        fv.set_default("name", job_name)
        fv.set_default("project", default_project())

        # A falsy current value (None or 0) is replaced by the fallback.
        fallbacks = (
            ("health_check_interval_sec", 10),
            ("health_check_timeout_sec", 5),
            ("health_check_healthy_threshold", 1),
            ("health_check_unhealthy_threshold", 3),
        )
        for flag_name, fallback in fallbacks:
            fv.set_default(flag_name, getattr(fv, flag_name) or fallback)

    @classmethod
    def from_flags(cls, fv: flags.FlagValues, **kwargs):
        """Builds a Config from `fv`, mapping the `health_check_*` flags onto config fields.

        Args:
            fv: Flag values to read from.
            **kwargs: Forwarded to the parent `from_flags`.

        Returns:
            The populated LWSHealthCheckPolicy.Config.
        """
        policy_cfg: LWSHealthCheckPolicy.Config = super().from_flags(fv, **kwargs)
        policy_cfg.name = fv.name
        # The flag names carry a `health_check_` prefix that the config fields do not.
        policy_cfg.check_interval_sec = fv.health_check_interval_sec
        policy_cfg.timeout_sec = fv.health_check_timeout_sec
        policy_cfg.healthy_threshold = fv.health_check_healthy_threshold
        policy_cfg.unhealthy_threshold = fv.health_check_unhealthy_threshold
        # Leave the config value untouched when the flag is undefined or falsy.
        port = getattr(fv, "health_check_port", None)
        if port:
            policy_cfg.health_check_port = port
        return policy_cfg

    @classmethod
    def default_config(cls):
        """Returns the parent class's default config unchanged."""
        default_cfg = super().default_config()
        return default_cfg

    def __init__(self, cfg: Config):
        """Stores a copy of `cfg` and derives the Service name and health check port."""
        super().__init__(cfg)
        logging.info("LWSHealthCheckPolicy class init")
        self._config = copy.deepcopy(cfg)
        self.name = cfg.name
        self.service_name = f"{cfg.name}-service"
        # Fall back to NOTARY_PROXY_HTTP_PORT when no port is configured.
        configured_port = cfg.health_check_port
        self.health_check_port = configured_port if configured_port else NOTARY_PROXY_HTTP_PORT

    def _build_health_check_policy(self) -> Nested[Any]:
        """Builds the HealthCheckPolicy manifest.

        Looks up the LeaderWorkerSet named `self.name` in the configured namespace so that its
        UID can be used in the policy's owner reference.

        Returns:
            A nested dict corresponding to a K8s HealthCheckPolicy config.
        """
        cfg = self.config
        logging.info("LWSHealthCheckPolicy class build")
        logging.info(str(cfg))

        # Imported lazily to avoid a circular dependency.
        from axlearn.cloud.gcp.utils import (  # pylint: disable=import-outside-toplevel
            custom_leaderworkerset_kwargs,
        )

        # Fetch the owning LeaderWorkerSet; its UID is needed for the owner reference.
        lws_api = custom_leaderworkerset_kwargs()
        client = k8s.client.CustomObjectsApi()
        group, version = lws_api["group"], lws_api["version"]
        owner_lws = client.get_namespaced_custom_object(
            group=group,
            version=version,
            namespace=cfg.namespace,
            plural=lws_api["plural"],
            name=self.name,
        )

        owner_reference = {
            "apiVersion": f"{group}/{version}",
            "kind": "LeaderWorkerSet",
            "name": self.name,
            "uid": owner_lws["metadata"]["uid"],
            "controller": True,
            "blockOwnerDeletion": True,
        }
        metadata = {
            "name": f"{self.name}-health-check",
            "namespace": cfg.namespace,
            "ownerReferences": [owner_reference],
        }
        probe_defaults = {
            "checkIntervalSec": cfg.check_interval_sec,
            "timeoutSec": cfg.timeout_sec,
            "healthyThreshold": cfg.healthy_threshold,
            "unhealthyThreshold": cfg.unhealthy_threshold,
            "config": {
                "type": "TCP",
                "tcpHealthCheck": {
                    "port": self.health_check_port,
                },
            },
        }
        # The policy attaches to the Service named `<name>-service`.
        target_ref = {
            "group": "",
            "kind": "Service",
            "name": self.service_name,
        }

        return {
            "apiVersion": "networking.gke.io/v1",
            "kind": "HealthCheckPolicy",
            "metadata": metadata,
            "spec": {
                "default": probe_defaults,
                "targetRef": target_ref,
            },
        }

    def execute(self):
        """Creates the HealthCheckPolicy in the cluster and returns the API response."""
        logging.info("LWSHealthCheckPolicy class execute")
        policy_body = self._build_health_check_policy()
        logging.info("Submitting LWSHealthCheckPolicy body=%s", policy_body)

        client = k8s.client.CustomObjectsApi()
        return client.create_namespaced_custom_object(
            group="networking.gke.io",
            version="v1",
            namespace=self.config.namespace,
            plural="healthcheckpolicies",
            body=policy_body,
        )

0 commit comments

Comments
 (0)