Skip to content

Commit 2da64ef

Browse files
authored
Introduce pod_spec_from_resources()ray helper function (#2943)
* expose requests & limits instead of k8spod Signed-off-by: Jan Fiedler <[email protected]> * put construct_k8s_pod_spec_from_resources into core/resources.py Signed-off-by: Jan Fiedler <[email protected]> * adjust ray tests Signed-off-by: Jan Fiedler <[email protected]> * ruff check fix Signed-off-by: Jan Fiedler <[email protected]> * ruff format Signed-off-by: Jan Fiedler <[email protected]> * remove demo files from PR Signed-off-by: Jan Fiedler <[email protected]> * remove kubernetes from ray plugin dependencies Signed-off-by: Jan Fiedler <[email protected]> * Update structured_dataset.py Signed-off-by: Jan Fiedler <[email protected]> * add tests for construct_k8s_pod_spec_from_resources Signed-off-by: Jan Fiedler <[email protected]> * add kubernetes to pyproject.toml Signed-off-by: Jan Fiedler <[email protected]> * add underscore prefix to _construct_k8s_pods_resources Signed-off-by: Jan Fiedler <[email protected]> * remove parantheses from_flyte_idl Signed-off-by: Jan Fiedler <[email protected]> * expose k8s_gpu_resource_key Signed-off-by: Jan Fiedler <[email protected]> * remove parantheses Signed-off-by: Jan Fiedler <[email protected]> * rename & expose nvidia gpu key Signed-off-by: Jan Fiedler <[email protected]> * adjust resource tests Signed-off-by: Jan Fiedler <[email protected]> * back to exposing k8s pod Signed-off-by: Jan Fiedler <[email protected]> * adjusts tests Signed-off-by: Jan Fiedler <[email protected]> * ruff Signed-off-by: Jan Fiedler <[email protected]> * ruff Signed-off-by: Jan Fiedler <[email protected]> * fix ray tests Signed-off-by: Jan Fiedler <[email protected]> * adjust ray README Signed-off-by: Jan Fiedler <[email protected]> * end of file readme Signed-off-by: Jan Fiedler <[email protected]> * default Resources to None Signed-off-by: Jan Fiedler <[email protected]> * remove optional from k8s_gpu_resource_key Signed-off-by: Jan Fiedler <[email protected]> --------- Signed-off-by: Jan Fiedler <[email protected]>
1 parent f99d50e commit 2da64ef

File tree

6 files changed

+141
-14
lines changed

6 files changed

+141
-14
lines changed

flytekit/core/resources.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from dataclasses import dataclass
2-
from typing import List, Optional, Union
1+
from dataclasses import dataclass, fields
2+
from typing import Any, List, Optional, Union
33

4+
from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
45
from mashumaro.mixins.json import DataClassJSONMixin
56

67
from flytekit.models import task as task_models
@@ -73,7 +74,10 @@ def _convert_resources_to_resource_entries(resources: Resources) -> List[_Resour
7374
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
7475
if resources.ephemeral_storage is not None:
7576
resource_entries.append(
76-
_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
77+
_ResourceEntry(
78+
name=_ResourceName.EPHEMERAL_STORAGE,
79+
value=str(resources.ephemeral_storage),
80+
)
7781
)
7882
return resource_entries
7983

@@ -96,3 +100,49 @@ def convert_resources_to_resource_model(
96100
if limits is not None:
97101
limit_entries = _convert_resources_to_resource_entries(limits)
98102
return task_models.Resources(requests=request_entries, limits=limit_entries)
103+
104+
105+
def pod_spec_from_resources(
106+
k8s_pod_name: str,
107+
requests: Optional[Resources] = None,
108+
limits: Optional[Resources] = None,
109+
k8s_gpu_resource_key: str = "nvidia.com/gpu",
110+
) -> dict[str, Any]:
111+
def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
112+
if resources is None:
113+
return None
114+
115+
resources_map = {
116+
"cpu": "cpu",
117+
"mem": "memory",
118+
"gpu": k8s_gpu_resource_key,
119+
"ephemeral_storage": "ephemeral-storage",
120+
}
121+
122+
k8s_pod_resources = {}
123+
124+
for resource in fields(resources):
125+
resource_value = getattr(resources, resource.name)
126+
if resource_value is not None:
127+
k8s_pod_resources[resources_map[resource.name]] = resource_value
128+
129+
return k8s_pod_resources
130+
131+
requests = _construct_k8s_pods_resources(resources=requests, k8s_gpu_resource_key=k8s_gpu_resource_key)
132+
limits = _construct_k8s_pods_resources(resources=limits, k8s_gpu_resource_key=k8s_gpu_resource_key)
133+
requests = requests or limits
134+
limits = limits or requests
135+
136+
k8s_pod = V1PodSpec(
137+
containers=[
138+
V1Container(
139+
name=k8s_pod_name,
140+
resources=V1ResourceRequirements(
141+
requests=requests,
142+
limits=limits,
143+
),
144+
)
145+
]
146+
)
147+
148+
return k8s_pod.to_dict()

plugins/flytekit-ray/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ To install the plugin, run the following command:
77
```bash
88
pip install flytekitplugins-ray
99
```
10+
11+
All [examples](https://docs.flyte.org/en/latest/flytesnacks/examples/ray_plugin/index.html) showcasing execution of Ray jobs using the plugin can be found in the documentation.

plugins/flytekit-ray/flytekitplugins/ray/task.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
9898
),
9999
worker_group_spec=[
100100
WorkerGroupSpec(
101-
c.group_name,
102-
c.replicas,
103-
c.min_replicas,
104-
c.max_replicas,
105-
c.ray_start_params,
106-
c.k8s_pod,
101+
c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params, c.k8s_pod
107102
)
108103
for c in cfg.worker_node_config
109104
],

plugins/flytekit-ray/tests/test_ray.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,29 @@
44
import ray
55
import yaml
66
from flytekitplugins.ray import HeadNodeConfig
7-
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, HeadGroupSpec
7+
from flytekitplugins.ray.models import (
8+
HeadGroupSpec,
9+
RayCluster,
10+
RayJob,
11+
WorkerGroupSpec,
12+
)
813
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
914
from google.protobuf.json_format import MessageToDict
10-
from flytekit.models.task import K8sPod
1115

1216
from flytekit import PythonFunctionTask, task
1317
from flytekit.configuration import Image, ImageConfig, SerializationSettings
18+
from flytekit.models.task import K8sPod
1419

1520
config = RayJobConfig(
16-
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))],
21+
worker_node_config=[
22+
WorkerNodeConfig(
23+
group_name="test_group",
24+
replicas=3,
25+
min_replicas=0,
26+
max_replicas=10,
27+
k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}),
28+
)
29+
],
1730
head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
1831
runtime_env={"pip": ["numpy"]},
1932
enable_autoscaling=True,
@@ -44,7 +57,19 @@ def t1(a: int) -> str:
4457
)
4558

4659
ray_job_pb = RayJob(
47-
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True),
60+
ray_cluster=RayCluster(
61+
worker_group_spec=[
62+
WorkerGroupSpec(
63+
group_name="test_group",
64+
replicas=3,
65+
min_replicas=0,
66+
max_replicas=10,
67+
k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}),
68+
)
69+
],
70+
head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
71+
enable_autoscaling=True,
72+
),
4873
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
4974
runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
5075
shutdown_after_job_finishes=True,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"jsonlines",
3434
"jsonpickle",
3535
"keyring>=18.0.1",
36+
"kubernetes>=12.0.1",
3637
"markdown-it-py",
3738
"marshmallow-enum",
3839
"marshmallow-jsonschema>=0.12.0",

tests/flytekit/unit/core/test_resources.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from typing import Dict
22

33
import pytest
4+
from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
45

56
import flytekit.models.task as _task_models
67
from flytekit import Resources
7-
from flytekit.core.resources import convert_resources_to_resource_model
8+
from flytekit.core.resources import (
9+
pod_spec_from_resources,
10+
convert_resources_to_resource_model,
11+
)
812

913
_ResourceName = _task_models.Resources.ResourceName
1014

@@ -101,3 +105,53 @@ def test_resources_round_trip():
101105
json_str = original.to_json()
102106
result = Resources.from_json(json_str)
103107
assert original == result
108+
109+
110+
def test_pod_spec_from_resources_requests_limits_set():
111+
requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi")
112+
limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi")
113+
k8s_pod_name = "foo"
114+
115+
expected_pod_spec = V1PodSpec(
116+
containers=[
117+
V1Container(
118+
name=k8s_pod_name,
119+
resources=V1ResourceRequirements(
120+
requests={
121+
"cpu": "1",
122+
"memory": "1Gi",
123+
"nvidia.com/gpu": "1",
124+
"ephemeral-storage": "1Gi",
125+
},
126+
limits={
127+
"cpu": "4",
128+
"memory": "2Gi",
129+
"nvidia.com/gpu": "1",
130+
"ephemeral-storage": "1Gi",
131+
},
132+
),
133+
)
134+
]
135+
)
136+
pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
137+
assert expected_pod_spec == V1PodSpec(**pod_spec)
138+
139+
140+
def test_pod_spec_from_resources_requests_set():
141+
requests = Resources(cpu="1", mem="1Gi")
142+
limits = None
143+
k8s_pod_name = "foo"
144+
145+
expected_pod_spec = V1PodSpec(
146+
containers=[
147+
V1Container(
148+
name=k8s_pod_name,
149+
resources=V1ResourceRequirements(
150+
requests={"cpu": "1", "memory": "1Gi"},
151+
limits={"cpu": "1", "memory": "1Gi"},
152+
),
153+
)
154+
]
155+
)
156+
pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
157+
assert expected_pod_spec == V1PodSpec(**pod_spec)

0 commit comments

Comments
 (0)