Skip to content

Commit 1f3d63f

Browse files
committed
[WIP] Interactive Ray support
Add support for creating remote kernels via Ray operator by introducing a RayOperatorProcessProxy
1 parent 0b2af8d commit 1f3d63f

File tree

19 files changed

+573
-19
lines changed

19 files changed

+573
-19
lines changed

Makefile

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
.PHONY: help clean clean-env dev dev-http docs install bdist sdist test release check_dists \
55
clean-images clean-enterprise-gateway clean-demo-base clean-kernel-images clean-enterprise-gateway \
6-
clean-kernel-py clean-kernel-spark-py clean-kernel-r clean-kernel-spark-r clean-kernel-scala clean-kernel-tf-py \
6+
clean-kernel-py clean-kernel-spark-py clean-kernel-ray-py clean-kernel-r clean-kernel-spark-r clean-kernel-scala clean-kernel-tf-py \
77
clean-kernel-tf-gpu-py clean-kernel-image-puller push-images push-enterprise-gateway-demo push-demo-base \
8-
push-kernel-images push-enterprise-gateway push-kernel-py push-kernel-spark-py push-kernel-r push-kernel-spark-r \
8+
push-kernel-images push-enterprise-gateway push-kernel-py push-kernel-spark-py push-kernel-ray-py push-kernel-r push-kernel-spark-r \
99
push-kernel-scala push-kernel-tf-py push-kernel-tf-gpu-py push-kernel-image-puller publish helm-chart
1010

1111
SA?=source activate
@@ -155,9 +155,9 @@ docker-images: ## Build docker images (includes kernel-based images)
155155
kernel-images: ## Build kernel-based docker images
156156

157157
# Actual working targets...
158-
docker-images: demo-base enterprise-gateway-demo kernel-images enterprise-gateway kernel-py kernel-spark-py kernel-r kernel-spark-r kernel-scala kernel-tf-py kernel-tf-gpu-py kernel-image-puller
158+
docker-images: demo-base enterprise-gateway-demo kernel-images enterprise-gateway kernel-py kernel-spark-py kernel-ray-py kernel-r kernel-spark-r kernel-scala kernel-tf-py kernel-tf-gpu-py kernel-image-puller
159159

160-
enterprise-gateway-demo kernel-images enterprise-gateway kernel-py kernel-spark-py kernel-r kernel-spark-r kernel-scala kernel-tf-py kernel-tf-gpu-py kernel-image-puller:
160+
enterprise-gateway-demo kernel-images enterprise-gateway kernel-py kernel-spark-py kernel-ray-py kernel-r kernel-spark-r kernel-scala kernel-tf-py kernel-tf-gpu-py kernel-image-puller:
161161
make WHEEL_FILE=$(WHEEL_FILE) VERSION=$(VERSION) NO_CACHE=$(NO_CACHE) TAG=$(TAG) SPARK_VERSION=$(SPARK_VERSION) MULTIARCH_BUILD=$(MULTIARCH_BUILD) TARGET_ARCH=$(TARGET_ARCH) -C etc $@
162162

163163
demo-base:
@@ -167,14 +167,14 @@ demo-base:
167167
clean-images: clean-demo-base ## Remove docker images (includes kernel-based images)
168168
clean-kernel-images: ## Remove kernel-based images
169169

170-
clean-images clean-enterprise-gateway-demo clean-kernel-images clean-enterprise-gateway clean-kernel-py clean-kernel-spark-py clean-kernel-r clean-kernel-spark-r clean-kernel-scala clean-kernel-tf-py clean-kernel-tf-gpu-py clean-kernel-image-puller:
170+
clean-images clean-enterprise-gateway-demo clean-kernel-images clean-enterprise-gateway clean-kernel-py clean-kernel-spark-py clean-kernel-ray-py clean-kernel-r clean-kernel-spark-r clean-kernel-scala clean-kernel-tf-py clean-kernel-tf-gpu-py clean-kernel-image-puller:
171171
make WHEEL_FILE=$(WHEEL_FILE) VERSION=$(VERSION) TAG=$(TAG) -C etc $@
172172

173173
clean-demo-base:
174174
make WHEEL_FILE=$(WHEEL_FILE) VERSION=$(VERSION) TAG=$(SPARK_VERSION) -C etc $@
175175

176176
push-images: push-demo-base
177-
push-images push-enterprise-gateway-demo push-kernel-images push-enterprise-gateway push-kernel-py push-kernel-spark-py push-kernel-r push-kernel-spark-r push-kernel-scala push-kernel-tf-py push-kernel-tf-gpu-py push-kernel-image-puller:
177+
push-images push-enterprise-gateway-demo push-kernel-images push-enterprise-gateway push-kernel-py push-kernel-spark-py push-kernel-ray-py push-kernel-r push-kernel-spark-r push-kernel-scala push-kernel-tf-py push-kernel-tf-gpu-py push-kernel-image-puller:
178178
make WHEEL_FILE=$(WHEEL_FILE) VERSION=$(VERSION) TAG=$(TAG) -C etc $@
179179

180180
push-demo-base:

enterprise_gateway/services/processproxies/container.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def poll(self) -> bool | None:
147147
# See https://github.com/jupyter-server/enterprise_gateway/issues/827
148148
if container_status in self.get_initial_states():
149149
result = None
150+
151+
self.log.debug(f">>> container.poll(): {container_status} --> {result}")
150152
return result
151153

152154
def send_signal(self, signum: int) -> bool | None:
@@ -188,6 +190,7 @@ def shutdown_listener(self):
188190

189191
async def confirm_remote_startup(self) -> None:
190192
"""Confirms the container has started and returned necessary connection information."""
193+
self.log.debug(">>> container.confirm_remote_startup()")
191194
self.log.debug("Trying to confirm kernel container startup status")
192195
self.start_time = RemoteProcessProxy.get_current_time()
193196
i = 0
@@ -197,21 +200,34 @@ async def confirm_remote_startup(self) -> None:
197200
await self.handle_timeout()
198201

199202
container_status = self.get_container_status(i)
203+
self.log.debug(
204+
f">>> container.confirm_remote_startup() - container_status: {container_status}"
205+
)
200206
if container_status:
201207
if container_status in self.get_error_states():
202208
self.log_and_raise(
203209
http_status_code=500,
204210
reason=f"Error starting kernel container; status: '{container_status}'.",
205211
)
206212
else:
213+
self.log.debug(
214+
f">>> container.confirm_remote_startup(): is hosted assigned => {self.assigned_host}"
215+
)
216+
self.log.debug(">>> should call receive_connection_info()")
207217
if self.assigned_host:
208218
ready_to_connect = await self.receive_connection_info()
219+
self.log.debug(
220+
f">>> container.confirm_remote_startup(): ready to connect => {ready_to_connect}"
221+
)
209222
self.pid = (
210223
0 # We won't send process signals for kubernetes lifecycle management
211224
)
212225
self.pgid = 0
213226
else:
214227
self.detect_launch_failure()
228+
self.log.debug(
229+
f">>> container.confirm_remote_startup(): ready to connect => {ready_to_connect}"
230+
)
215231

216232
def get_process_info(self) -> dict[str, Any]:
217233
"""Captures the base information necessary for kernel persistence relative to containers."""

enterprise_gateway/services/processproxies/crd.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,15 @@ def get_container_status(self, iteration: int | None) -> str:
7474
)
7575

7676
if custom_resource:
77-
application_state = custom_resource['status']['applicationState']['state'].lower()
77+
application_state = custom_resource.get("status", {}).get("state", "").lower()
78+
79+
self.log.debug(f">>> crd.get_container_status: {application_state}")
7880

7981
if application_state in self.get_error_states():
8082
exception_text = self._get_exception_text(
81-
custom_resource['status']['applicationState']['errorMessage']
83+
custom_resource.get("status", {})
84+
.get("applicationState", {})
85+
.get("errorMessage")
8286
)
8387
error_message = (
8488
f"CRD submission for kernel {self.kernel_id} failed: {exception_text}"

enterprise_gateway/services/processproxies/k8s.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def get_container_status(self, iteration: int | None) -> str:
115115
self.container_name = pod_info.metadata.name
116116
if pod_info.status:
117117
pod_status = pod_info.status.phase.lower()
118+
self.log.debug(f">>> k8s.get_container_status: {pod_status}")
118119
if pod_status == "running" and not self.assigned_host:
119120
# Pod is running, capture IP
120121
self.assigned_ip = pod_info.status.pod_ip
@@ -128,6 +129,7 @@ def get_container_status(self, iteration: int | None) -> str:
128129
f"Status: '{pod_status}', Pod IP: '{self.assigned_ip}', KernelID: '{self.kernel_id}'"
129130
)
130131

132+
self.log.debug(f">>> k8s.get_container_status: {pod_status}")
131133
return pod_status
132134

133135
def delete_managed_object(self, termination_stati: list[str]) -> bool:

enterprise_gateway/services/processproxies/processproxy.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def register_event(self, kernel_id: str) -> None:
201201

202202
async def get_connection_info(self, kernel_id: str) -> dict:
203203
"""Performs a timeout wait on the event, returning the conenction information on completion."""
204+
self.log.debug(f">>> processproxy.get_connection_info() for kernel_id {kernel_id}")
204205
await asyncio.wait_for(self._response_registry[kernel_id].wait(), connection_interval)
205206
return self._response_registry.pop(kernel_id).response
206207

@@ -1300,9 +1301,13 @@ async def receive_connection_info(self) -> bool:
13001301
"""
13011302
# Polls the socket using accept. When data is found, returns ready indicator and encrypted data.
13021303
ready_to_connect = False
1303-
1304+
self.log.debug(
1305+
f">>> processproxy.receive_connection_info(): initializing ready to connect as {ready_to_connect}"
1306+
)
13041307
try:
13051308
connect_info = await self.response_manager.get_connection_info(self.kernel_id)
1309+
self.log.debug(">>> processproxy.receive_connection_info(): connect info received")
1310+
self.log.debug(connect_info)
13061311
self._setup_connection_info(connect_info)
13071312
ready_to_connect = True
13081313
except Exception as e:
@@ -1320,6 +1325,9 @@ async def receive_connection_info(self) -> bool:
13201325
self.kill()
13211326
self.log_and_raise(http_status_code=500, reason=error_message)
13221327

1328+
self.log.debug(
1329+
f">>> processproxy.receive_connection_info(): returning ready to connect {ready_to_connect}"
1330+
)
13231331
return ready_to_connect
13241332

13251333
def _setup_connection_info(self, connect_info: dict) -> None:
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
"""A Ray operator process proxy."""
2+
3+
# Internal implementation at Apple
4+
from __future__ import annotations
5+
6+
from typing import Any
7+
8+
from kubernetes import client
9+
10+
from ..kernels.remotemanager import RemoteKernelManager
11+
from .k8s import KubernetesProcessProxy
12+
13+
14+
class RayOperatorProcessProxy(KubernetesProcessProxy):
15+
"""Ray operator process proxy."""
16+
17+
object_kind = "RayCluster"
18+
19+
def __init__(self, kernel_manager: RemoteKernelManager, proxy_config: dict):
20+
"""Initialize the proxy."""
21+
super().__init__(kernel_manager, proxy_config)
22+
self.group = "ray.io"
23+
self.version = "v1alpha1"
24+
self.plural = "rayclusters"
25+
26+
async def launch_process(
27+
self, kernel_cmd: str, **kwargs: dict[str, Any] | None
28+
) -> RayOperatorProcessProxy:
29+
"""Launch the process for a kernel."""
30+
self.kernel_resource_name = self._determine_kernel_pod_name(**kwargs)
31+
kwargs["env"]["KERNEL_RESOURCE_NAME"] = self.kernel_resource_name
32+
kwargs["env"]["KERNEL_CRD_GROUP"] = self.group
33+
kwargs["env"]["KERNEL_CRD_VERSION"] = self.version
34+
kwargs["env"]["KERNEL_CRD_PLURAL"] = self.plural
35+
36+
await super().launch_process(kernel_cmd, **kwargs)
37+
return self
38+
39+
def get_container_status(self, iteration: int | None) -> str:
40+
"""Determines submitted Ray application status and returns unified pod state.
41+
42+
This method returns the pod status (not CRD status) to maintain compatibility
43+
with the base class lifecycle management. The RayCluster CRD state is checked
44+
first to ensure the cluster is healthy, but we return pod states that the
45+
base class understands: 'pending', 'running', 'failed', etc.
46+
"""
47+
application_state = None
48+
head_pod_status = None
49+
application_state = self._get_application_state()
50+
if application_state:
51+
self.log.debug(
52+
f">>> ray_operator.get_container_status: application_state {application_state}"
53+
)
54+
55+
# Check for CRD-level errors first
56+
if application_state in self.get_error_states():
57+
error_message = (
58+
f"CRD submission for kernel {self.kernel_id} failed with state: {application_state}"
59+
)
60+
self.log.error(error_message)
61+
return "failed" # Return pod state, not CRD state
62+
63+
# If CRD is not ready yet, return "pending" to indicate still launching
64+
if application_state != "ready":
65+
self.log.debug(
66+
f">>> ray_operator.get_container_status: CRD not ready yet, state={application_state}"
67+
)
68+
return "pending"
69+
70+
# CRD is ready, now check the actual pod status
71+
kernel_label_selector = "kernel_id=" + self.kernel_id + ",component=kernel"
72+
ret = None
73+
try:
74+
ret = client.CoreV1Api().list_namespaced_pod(
75+
namespace=self.kernel_namespace, label_selector=kernel_label_selector
76+
)
77+
except client.rest.ApiException as e:
78+
if e.status == 404:
79+
self.log.debug("Resetting cluster connection info as cluster deleted")
80+
self._reset_connection_info()
81+
return None
82+
83+
if ret and ret.items:
84+
pod_info = ret.items[0]
85+
self.log.debug(
86+
f"Cluster status {application_state}, pod status {pod_info.status.phase.lower()}"
87+
)
88+
if pod_info.status:
89+
head_pod_status = pod_info.status.phase.lower()
90+
self.log.debug(
91+
f">>> ray_operator.get_container_status: pod_status {head_pod_status}"
92+
)
93+
if head_pod_status == "running":
94+
self.log.debug(
95+
f"Pod Info name:{pod_info.metadata.name}, pod ip {pod_info.status.pod_ip}, host {self.container_name}"
96+
)
97+
self.container_name = pod_info.metadata.name
98+
self.assigned_ip = pod_info.status.pod_ip
99+
self.assigned_host = self.container_name
100+
self.assigned_node_ip = pod_info.status.host_ip
101+
102+
# only log if iteration is not None (otherwise poll() is too noisy)
103+
# check for running state to avoid double logging with superclass
104+
if iteration and head_pod_status != 'running':
105+
self.log.debug(
106+
f"{iteration}: Waiting from CRD status from resource manager {self.object_kind.lower()} in "
107+
f"namespace '{self.kernel_namespace}'. Name: '{self.kernel_resource_name}', "
108+
f"Status: CRD='{application_state}', Pod='{head_pod_status}', KernelID: '{self.kernel_id}'"
109+
)
110+
111+
# KEY FIX: Return pod status (not CRD state) so base class poll() works correctly
112+
final_status = head_pod_status if head_pod_status else "pending"
113+
self.log.debug(
114+
f">>> ray_operator.get_container_status: returning pod_status={final_status} "
115+
f"(CRD state was {application_state})"
116+
)
117+
return final_status
118+
119+
def delete_managed_object(self, termination_stati: list[str]) -> bool:
120+
"""Deletes the object managed by this process-proxy
121+
122+
A return value of True indicates the object is considered deleted,
123+
otherwise a False or None value is returned.
124+
125+
Note: the caller is responsible for handling exceptions.
126+
"""
127+
delete_status = client.CustomObjectsApi().delete_namespaced_custom_object(
128+
self.group,
129+
self.version,
130+
self.kernel_namespace,
131+
self.plural,
132+
self.kernel_resource_name,
133+
grace_period_seconds=0,
134+
propagation_policy="Background",
135+
)
136+
137+
result = delete_status and delete_status.get("status", None) in termination_stati
138+
if result:
139+
self._reset_connection_info()
140+
return result
141+
142+
def get_initial_states(self) -> set:
143+
"""Return list of states indicating container is starting (includes running).
144+
145+
Note: We return pod states (not CRD states) to maintain compatibility
146+
with the base class poll() implementation, which checks if the status
147+
returned by get_container_status() is in this set.
148+
"""
149+
return ["pending", "running"]
150+
151+
def get_error_states(self) -> set:
152+
"""Return list of states indicating RayCluster has failed."""
153+
# Ray doesn't typically use "failed" state, but we'll include common error states
154+
return {"failed", "error", "unhealthy"}
155+
156+
def _get_ray_cluster_status(self) -> dict:
157+
try:
158+
return client.CustomObjectsApi().get_namespaced_custom_object(
159+
self.group,
160+
self.version,
161+
self.kernel_namespace,
162+
self.plural,
163+
self.kernel_resource_name,
164+
)
165+
except client.rest.ApiException as e:
166+
if e.status == 404:
167+
self.log.debug("Resetting cluster connection info as cluster deleted")
168+
self._reset_connection_info()
169+
return None
170+
171+
def _get_application_state(self):
172+
custom_resource = self._get_ray_cluster_status()
173+
174+
if custom_resource is None:
175+
return None
176+
177+
if 'status' not in custom_resource or 'state' not in custom_resource['status']:
178+
return None
179+
180+
return custom_resource['status']['state'].lower()
181+
182+
def _get_pod_status(self) -> str:
183+
"""Get the current status of the kernel pod.
184+
Returns
185+
-------
186+
str
187+
The pod status in lowercase (e.g., 'pending', 'running', 'failed', 'unknown').
188+
"""
189+
pod_status = "unknown"
190+
kernel_label_selector = "kernel_id=" + self.kernel_id + ",component=kernel"
191+
ret = client.CoreV1Api().list_namespaced_pod(
192+
namespace=self.kernel_namespace, label_selector=kernel_label_selector
193+
)
194+
if ret and ret.items:
195+
pod_info = ret.items[0]
196+
self.container_name = pod_info.metadata.name
197+
if pod_info.status:
198+
pod_status = pod_info.status.phase.lower()
199+
self.log.debug(f">>> k8s._get_pod_status: {pod_status}")
200+
201+
return pod_status
202+
203+
def _reset_connection_info(self):
204+
"""Reset all connection-related attributes to their initial state.
205+
This is typically called when a cluster is deleted or connection is lost.
206+
"""
207+
208+
self.assigned_host = None
209+
self.container_name = ""
210+
self.assigned_node_ip = None
211+
self.assigned_ip = None

enterprise_gateway/services/sessions/kernelsessionmanager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def create_session(self, kernel_id: str, **kwargs) -> None:
9494
Information used for the launch of the kernel
9595
9696
"""
97+
self.log.debug(f">>> Creating new session for kernel {kernel_id}")
9798
km = self.kernel_manager.get_kernel(kernel_id)
9899

99100
# Compose the kernel_session entry
@@ -103,11 +104,14 @@ def create_session(self, kernel_id: str, **kwargs) -> None:
103104
kernel_session["kernel_name"] = km.kernel_name
104105

105106
# Build the inner dictionaries: connection_info, process_proxy and add to kernel_session
107+
self.log.debug(f">>> Getting connection info for kernel {kernel_id}")
106108
kernel_session["connection_info"] = km.get_connection_info()
107109
kernel_session["launch_args"] = kwargs.copy()
110+
self.log.debug(f">>> Getting process info for kernel {kernel_id}")
108111
kernel_session["process_info"] = (
109112
km.process_proxy.get_process_info() if km.process_proxy else {}
110113
)
114+
self.log.debug(f">>> Saving session {kernel_session}")
111115
self._save_session(kernel_id, kernel_session)
112116

113117
def refresh_session(self, kernel_id: str) -> None:

0 commit comments

Comments
 (0)