Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ def __init__(
self,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
enable_ingress: bool = False,
):
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod
self._enable_ingress = enable_ingress

@property
def ray_start_params(self):
Expand All @@ -125,13 +127,22 @@ def k8s_pod(self):
"""
return self._k8s_pod

@property
def enable_ingress(self):
"""
Whether to enable an ingress on the head node.
:rtype: bool
"""
return self._enable_ingress

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec
"""
return _ray_pb2.HeadGroupSpec(
ray_start_params=self.ray_start_params if self.ray_start_params else {},
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
enable_ingress=self.enable_ingress,
)

@classmethod
Expand All @@ -143,6 +154,7 @@ def from_flyte_idl(cls, proto):
return cls(
ray_start_params=proto.ray_start_params,
k8s_pod=K8sPod.from_flyte_idl(proto.k8s_pod) if proto.HasField("k8s_pod") else None,
enable_ingress=proto.enable_ingress,
)


Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

@dataclass
class HeadNodeConfig:
enable_ingress: bool = False
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
pod_template: typing.Optional[PodTemplate] = None
requests: Optional[Resources] = None
Expand Down Expand Up @@ -122,6 +123,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
head_group_spec = HeadGroupSpec(
cfg.head_node_config.ray_start_params,
K8sPod.from_pod_template(head_pod_template) if head_pod_template else None,
enable_ingress=cfg.head_node_config.enable_ingress,
)

worker_group_spec: typing.List[WorkerGroupSpec] = []
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
pod_template=pod_template,
)
],
head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="1Gi"), limits=Resources(cpu="2", mem="2Gi")),
head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="1Gi"), limits=Resources(cpu="2", mem="2Gi"), enable_ingress=True),
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
Expand Down Expand Up @@ -84,7 +84,7 @@ def t1(a: int) -> str:
k8s_pod=K8sPod.from_pod_template(pod_template),
)
],
head_group_spec=HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template)),
head_group_spec=HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template), enable_ingress=True),
enable_autoscaling=True,
),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
Expand Down
Loading