Commit 8a6bbd0

Authored by: shuyingliang, Future-Outlier, thomasjpfan, JiangJiaWei1103, wild-endeavor
Add the Flyte agent to provision and manage K8s (data) service for deep learning (GNN) use cases (#3004)
Signed-off-by: Shuying Liang <shuying.liang@gmail.com>
Signed-off-by: Future-Outlier <eric901201@gmail.com>
Signed-off-by: JiaWei Jiang <waynechuang97@gmail.com>
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>
Co-authored-by: Han-Ru Chen (Future-Outlier) <eric901201@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: 江家瑋 <36886416+JiangJiaWei1103@users.noreply.github.com>
Co-authored-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
Co-authored-by: Niels Bantilan <niels.bantilan@gmail.com>
1 parent e9fa4fb commit 8a6bbd0

File tree

27 files changed: +1562 -9 lines changed


.github/workflows/pythonbuild.yml

Lines changed: 1 addition & 0 deletions
@@ -303,6 +303,7 @@ jobs:
         - flytekit-huggingface
         - flytekit-identity-aware-proxy
         - flytekit-inference
+        - flytekit-k8sdataservice
         - flytekit-k8s-pod
         - flytekit-kf-mpi
         - flytekit-kf-pytorch

.pre-commit-config.yaml

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.9
+    rev: v0.8.3
     hooks:
       # Run the linter.
       - id: ruff
@@ -26,5 +26,6 @@ repos:
     rev: v2.3.0
     hooks:
       - id: codespell
-        additional_dependencies:
-        - tomli
+        args:
+          - --ignore-words-list=assertIn  # Ignore 'assertIn'
+        additional_dependencies: [tomli]

Dockerfile.agent

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ RUN apt-get update && apt-get install build-essential -y \
 RUN uv pip install --system --no-cache-dir -U flytekit==$VERSION \
     flytekitplugins-airflow==$VERSION \
     flytekitplugins-bigquery==$VERSION \
+    flytekitplugins-k8sdataservice==$VERSION \
     flytekitplugins-openai==$VERSION \
     flytekitplugins-snowflake==$VERSION \
     flytekitplugins-awssagemaker==$VERSION \
Lines changed: 12 additions & 0 deletions
.. k8sstatefuldataservice:

###################################################
Kubernetes StatefulSet Data Service API reference
###################################################

.. tags:: Integration, DeepLearning, MachineLearning, Kubernetes, GNN

.. automodule:: flytekitplugins.k8sdataservice
   :no-members:
   :no-inherited-members:
   :no-special-members:
Lines changed: 104 additions & 0 deletions
# K8s Stateful Service Plugin

This plugin adds Kubernetes StatefulSet and Service integration, enabling seamless provisioning of, and coordination with, Kubernetes services from Flyte tasks. It is especially suited for deep learning use cases at scale, where distributed, parallelized data loading and caching across nodes are required.
## Features

- **Predictable and Reliable Endpoints**: The service creates consistent endpoints, facilitating communication between services or tasks within the same Kubernetes cluster.
- **Reusable Across Runs**: Service tasks can persist across task runs, ensuring consistency. Alternatively, a cleanup sensor can release cluster resources when they are no longer needed.
- **Conventional Pod Naming**: Pods in the StatefulSet follow a predictable naming pattern. For instance, if the StatefulSet name is `foo` and replicas are set to 2, the pod endpoints will be `foo-0.foo:1234` and `foo-1.foo:1234`. This simplifies endpoint construction for training or inference scripts; gRPC clients, for example, can dial those addresses directly.
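Given this naming convention, a client script can derive the endpoints itself. A minimal sketch (the service name `foo`, replica count, and port `1234` are illustrative values, not plugin API):

```python
def data_service_endpoints(name: str, replicas: int, port: int) -> list:
    """Build the stable pod endpoints for a StatefulSet-backed service.

    Pod i of StatefulSet `name` is reachable at `{name}-{i}.{name}:{port}`
    via the headless service of the same name.
    """
    return [f"{name}-{i}.{name}:{port}" for i in range(replicas)]


print(data_service_endpoints("foo", 2, 1234))
# → ['foo-0.foo:1234', 'foo-1.foo:1234']
```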
## Installation

Install the plugin via pip:

```bash
pip install flytekitplugins-k8sdataservice
```
## Usage

The example below provisions and runs a service in Kubernetes, making it reachable from within the cluster.

**Note**: Utility functions are available to generate unique service names that can be reused across training or inference scripts.
### Example Usage

#### Provisioning a Data Service

```python
from flytekitplugins.k8sdataservice import CleanupSensor, DataServiceConfig, DataServiceTask
from utils.infra import gen_infra_name

from flytekit import kwtypes, Resources, workflow

# Generate a unique infrastructure name
name = gen_infra_name()


def k8s_data_service():
    gnn_config = DataServiceConfig(
        Name=name,
        Requests=Resources(cpu="1", mem="1Gi"),
        Limits=Resources(cpu="2", mem="2Gi"),
        Replicas=1,
        Image="busybox:latest",
        Command=[
            "bash",
            "-c",
            "echo Hello Flyte K8s Stateful Service! && sleep 3600",
        ],
    )

    gnn_task = DataServiceTask(
        name="K8s Stateful Data Service",
        inputs=kwtypes(ds=str),
        task_config=gnn_config,
    )
    return gnn_task


# Define a cleanup sensor
gnn_sensor = CleanupSensor(name="Cleanup")


# Define a workflow to test the data service
@workflow
def test_dataservice_wf(name: str):
    k8s_data_service()(ds="OSS Flyte K8s Data Service Demo") \
        >> gnn_sensor(
            release_name=name,
            cleanup_data_service=True,
        )


if __name__ == "__main__":
    out = test_dataservice_wf(name="example")
    print(f"Running test_dataservice_wf() {out}")
```
#### Accessing the Data Service

Other tasks or services that need to access the data service can do so in multiple ways, for example via environment variables:

```python
from kubernetes.client import V1Container, V1EnvVar, V1PodSpec

from flytekit import task
from flytekitplugins.kfmpi import Launcher, MPIJob, Worker  # MPI plugin (flytekitplugins-kf-mpi)

PRIMARY_CONTAINER_NAME = "primary"
FLYTE_POD_SPEC = V1PodSpec(
    containers=[
        V1Container(
            name=PRIMARY_CONTAINER_NAME,
            env=[
                V1EnvVar(name="MY_DATASERVICES", value=f"{name}-0.{name}:40000 {name}-1.{name}:40000"),
            ],
        )
    ],
)

task_config = MPIJob(
    launcher=Launcher(replicas=1, pod_template=FLYTE_POD_SPEC),
    worker=Worker(replicas=1, pod_template=FLYTE_POD_SPEC),
)


@task(task_config=task_config)
def mpi_task() -> str:
    return "your script uses the envs to communicate with the data service"
```
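On the consuming side, a task's script can split `MY_DATASERVICES` back into individual endpoints. A minimal, dependency-free sketch (the env var name matches the pod spec above; the parsing helper itself is illustrative):

```python
import os


def parse_endpoints(env_var: str = "MY_DATASERVICES") -> list:
    """Split the space-separated `host:port` pairs injected via the pod spec."""
    pairs = os.environ.get(env_var, "").split()
    # rsplit on the last colon so only the port is separated;
    # each entry becomes a (host, port) tuple.
    return [(ep.rsplit(":", 1)[0], int(ep.rsplit(":", 1)[1])) for ep in pairs]


# Simulate the environment the pod spec above would create.
os.environ["MY_DATASERVICES"] = "foo-0.foo:40000 foo-1.foo:40000"
print(parse_endpoints())
# → [('foo-0.foo', 40000), ('foo-1.foo', 40000)]
```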
### Key Points

- `DataServiceConfig` defines the resource requests, limits, replica count, and the container image and command.
- `CleanupSensor` ensures cluster resources are cleaned up when required.
- The workflow chains service provisioning and cleanup for streamlined operations.
Lines changed: 1 addition & 0 deletions
kubernetes~=23.6.0
Lines changed: 15 additions & 0 deletions
"""
.. currentmodule:: flytekitplugins.k8sdataservice

This package contains things that are useful when extending Flytekit.

.. autosummary::
   :template: custom.rst
   :toctree: generated/

   DataServiceTask
"""

from .agent import DataServiceAgent  # noqa: F401
from .sensor import CleanupSensor  # noqa: F401
from .task import DataServiceConfig, DataServiceTask  # noqa: F401
Lines changed: 98 additions & 0 deletions
from dataclasses import dataclass
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.k8sdataservice.k8s.manager import K8sManager
from flytekitplugins.k8sdataservice.task import DataServiceConfig

from flytekit import logger
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class DataServiceMetadata(ResourceMeta):
    dataservice_config: DataServiceConfig
    name: str


class DataServiceAgent(AsyncAgentBase):
    name = "K8s DataService Async Agent"

    def __init__(self):
        self.k8s_manager = K8sManager()
        super().__init__(task_type_name="dataservicetask", metadata_type=DataServiceMetadata)
        self.config = None

    def create(
        self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
    ) -> DataServiceMetadata:
        graph_engine_config = task_template.custom
        self.k8s_manager.set_configs(graph_engine_config)
        logger.info(f"Loaded agent config file {self.config}")
        existing_release_name = graph_engine_config.get("ExistingReleaseName", None)
        logger.info(f"The existing data service release name is {existing_release_name}")

        name = ""
        if existing_release_name is None or existing_release_name == "":
            logger.info("Creating K8s data service resources...")
            name = self.k8s_manager.create_data_service()
            logger.info(f'Data service {name} with image {graph_engine_config["Image"]} completed')
        else:
            name = existing_release_name
            logger.info(f"User configs to use the existing data service release name: {name}.")

        dataservice_config = DataServiceConfig(
            Name=graph_engine_config.get("Name", None),
            Image=graph_engine_config["Image"],
            Command=graph_engine_config["Command"],
            Cluster=graph_engine_config["Cluster"],
            ExistingReleaseName=graph_engine_config.get("ExistingReleaseName", None),
        )
        metadata = DataServiceMetadata(
            dataservice_config=dataservice_config,
            name=name,
        )
        logger.info(f"Created DataService metadata {metadata}")
        return metadata

    def get(self, resource_meta: DataServiceMetadata) -> Resource:
        logger.info("K8s Data Service get is called")
        data = resource_meta.dataservice_config
        data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data
        logger.info(f"The data_dict is {data_dict}")
        self.k8s_manager.set_configs(data_dict)
        name = data.Name
        logger.info(f"Get the stateful set name {name}")

        k8s_status = self.k8s_manager.check_stateful_set_status(name)
        flyte_state = None
        if k8s_status in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
            flyte_state = TaskExecution.FAILED
        elif k8s_status in ["done", "succeeded", "success"]:
            flyte_state = TaskExecution.SUCCEEDED
        elif k8s_status in ["running", "terminating", "pending"]:
            flyte_state = TaskExecution.RUNNING
        else:
            logger.error(f"Unrecognized state: {k8s_status}")
        outputs = {
            "data_service_name": name,
        }
        # TODO: Add logs for StatefulSet.
        return Resource(phase=flyte_state, outputs=outputs)

    def delete(self, resource_meta: DataServiceMetadata):
        logger.info("DataService delete is called")
        data = resource_meta.dataservice_config

        data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data
        self.k8s_manager.set_configs(data_dict)

        name = resource_meta.name
        logger.info(f"To delete the DataService (e.g., StatefulSet and Service) with name {name}")
        self.k8s_manager.delete_stateful_set(name)
        self.k8s_manager.delete_service(name)


AgentRegistry.register(DataServiceAgent())
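The status translation in `get` can be read as a lookup table. A dependency-free sketch (plain strings stand in for the `TaskExecution` enum values; the status strings are copied from the code above):

```python
# Map StatefulSet status strings (as reported by K8sManager) to Flyte phases.
K8S_TO_FLYTE_PHASE = {
    **dict.fromkeys(["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"], "FAILED"),
    **dict.fromkeys(["done", "succeeded", "success"], "SUCCEEDED"),
    **dict.fromkeys(["running", "terminating", "pending"], "RUNNING"),
}


def to_flyte_phase(k8s_status: str):
    # Unrecognized statuses return None, mirroring the agent's error-logging branch.
    return K8S_TO_FLYTE_PHASE.get(k8s_status)


print(to_flyte_phase("pending"))
# → RUNNING
```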

plugins/flytekit-k8sdataservice/flytekitplugins/k8sdataservice/k8s/__init__.py

Whitespace-only changes.
Lines changed: 20 additions & 0 deletions
from kubernetes import config

from flytekit import logger


class KubeConfig:
    def __init__(self):
        pass

    def load_kube_config(self) -> None:
        """Load the in-cluster Kubernetes config prior to K8s client usage."""
        try:
            logger.info("Attempting to load in-cluster configuration.")
            config.load_incluster_config()  # This will use the service account credentials
            logger.info("Successfully loaded in-cluster configuration using the agent service account.")
        except config.ConfigException as e:
            logger.warning(f"Failed to load in-cluster configuration. {e}")
