diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py index 1f3a830f16..aeb2cb890e 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/models.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py @@ -105,9 +105,11 @@ def __init__( self, ray_start_params: typing.Optional[typing.Dict[str, str]] = None, k8s_pod: typing.Optional[K8sPod] = None, + enable_ingress: bool = False, ): self._ray_start_params = ray_start_params self._k8s_pod = k8s_pod + self._enable_ingress = enable_ingress @property def ray_start_params(self): @@ -125,6 +127,14 @@ def k8s_pod(self): """ return self._k8s_pod + @property + def enable_ingress(self): + """ + Whether to enable an ingress on the head node. + :rtype: bool + """ + return self._enable_ingress + def to_flyte_idl(self): """ :rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec @@ -132,6 +142,7 @@ def to_flyte_idl(self): return _ray_pb2.HeadGroupSpec( ray_start_params=self.ray_start_params if self.ray_start_params else {}, k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None, + enable_ingress=self.enable_ingress, ) @classmethod @@ -143,6 +154,7 @@ def from_flyte_idl(cls, proto): return cls( ray_start_params=proto.ray_start_params, k8s_pod=K8sPod.from_flyte_idl(proto.k8s_pod) if proto.HasField("k8s_pod") else None, + enable_ingress=proto.enable_ingress, ) diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py index 9793e9d5d9..4f6e3ba0b9 100644 --- a/plugins/flytekit-ray/flytekitplugins/ray/task.py +++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py @@ -29,6 +29,7 @@ @dataclass class HeadNodeConfig: + enable_ingress: bool = False ray_start_params: typing.Optional[typing.Dict[str, str]] = None pod_template: typing.Optional[PodTemplate] = None requests: Optional[Resources] = None @@ -122,6 +123,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] head_group_spec = HeadGroupSpec( cfg.head_node_config.ray_start_params, K8sPod.from_pod_template(head_pod_template) if head_pod_template else None, + enable_ingress=cfg.head_node_config.enable_ingress, ) worker_group_spec: typing.List[WorkerGroupSpec] = [] diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py index c9b00a6dad..faf2dc114e 100644 --- a/plugins/flytekit-ray/tests/test_ray.py +++ b/plugins/flytekit-ray/tests/test_ray.py @@ -36,7 +36,7 @@ pod_template=pod_template, ) ], - head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="1Gi"), limits=Resources(cpu="2", mem="2Gi")), + head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="1Gi"), limits=Resources(cpu="2", mem="2Gi"), enable_ingress=True), runtime_env={"pip": ["numpy"]}, enable_autoscaling=True, shutdown_after_job_finishes=True, @@ -84,7 +84,7 @@ def t1(a: int) -> str: k8s_pod=K8sPod.from_pod_template(pod_template), ) ], - head_group_spec=HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template)), + head_group_spec=HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template), enable_ingress=True), enable_autoscaling=True, ), runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),