Skip to content

Commit 73a16c3

Browse files
Implement server-side ypywidgets rendering (#364)
* Implement server-side ypywidgets rendering * Fix types * Use ypywidgets-textual in tests * Update with ypywidgets v0.6.1 and ypywidgets-textual 0.2.1 * Use cell ID instead of cell index in execute API * Add JupyterLab server_side_execution flag * Set shared document file_id
1 parent c1f68a3 commit 73a16c3

File tree

19 files changed

+541
-65
lines changed

19 files changed

+541
-65
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,6 @@ $RECYCLE.BIN/
344344
.jupyter_ystore.db
345345
.jupyter_ystore.db-journal
346346
fps_cli_args.toml
347+
348+
# pixi environments
349+
.pixi

jupyverse_api/jupyverse_api/jupyterlab/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,4 @@ async def get_workspace(
8989

9090
class JupyterLabConfig(Config):
9191
dev_mode: bool = False
92+
server_side_execution: bool = False

jupyverse_api/jupyverse_api/kernels/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ class Session(BaseModel):
3939

4040
class Execution(BaseModel):
4141
document_id: str
42-
cell_idx: int
42+
cell_id: str

plugins/jupyterlab/fps_jupyterlab/routes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ async def get_lab(
5858
self.get_index(
5959
"default",
6060
self.frontend_config.collaborative,
61+
self.jupyterlab_config.server_side_execution,
6162
self.jupyterlab_config.dev_mode,
6263
self.frontend_config.base_url,
6364
)
@@ -71,6 +72,7 @@ async def load_workspace(
7172
self.get_index(
7273
"default",
7374
self.frontend_config.collaborative,
75+
self.jupyterlab_config.server_side_execution,
7476
self.jupyterlab_config.dev_mode,
7577
self.frontend_config.base_url,
7678
)
@@ -99,11 +101,12 @@ async def get_workspace(
99101
return self.get_index(
100102
name,
101103
self.frontend_config.collaborative,
104+
self.jupyterlab_config.server_side_execution,
102105
self.jupyterlab_config.dev_mode,
103106
self.frontend_config.base_url,
104107
)
105108

106-
def get_index(self, workspace, collaborative, dev_mode, base_url="/"):
109+
def get_index(self, workspace, collaborative, server_side_execution, dev_mode, base_url="/"):
107110
for path in (self.static_lab_dir).glob("main.*.js"):
108111
main_id = path.name.split(".")[1]
109112
break
@@ -121,6 +124,7 @@ def get_index(self, workspace, collaborative, dev_mode, base_url="/"):
121124
"baseUrl": base_url,
122125
"cacheFiles": False,
123126
"collaborative": collaborative,
127+
"serverSideExecution": server_side_execution,
124128
"devMode": dev_mode,
125129
"disabledExtensions": self.disabled_extension,
126130
"exposeAppInBrowser": False,

plugins/kernels/fps_kernels/kernel_driver/driver.py

Lines changed: 131 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import uuid
55
from typing import Any, Dict, List, Optional, cast
66

7+
from pycrdt import Array, Map
8+
9+
from jupyverse_api.yjs import Yjs
10+
711
from .connect import cfg_t, connect_channel, launch_kernel, read_connection_file
812
from .connect import write_connection_file as _write_connection_file
913
from .kernelspec import find_kernelspec
@@ -23,10 +27,12 @@ def __init__(
2327
connection_file: str = "",
2428
write_connection_file: bool = True,
2529
capture_kernel_output: bool = True,
30+
yjs: Optional[Yjs] = None,
2631
) -> None:
2732
self.capture_kernel_output = capture_kernel_output
2833
self.kernelspec_path = kernelspec_path or find_kernelspec(kernel_name)
2934
self.kernel_cwd = kernel_cwd
35+
self.yjs = yjs
3036
if not self.kernelspec_path:
3137
raise RuntimeError("Could not find a kernel, maybe you forgot to install one?")
3238
if write_connection_file:
@@ -37,11 +43,12 @@ def __init__(
3743
self.key = cast(str, self.connection_cfg["key"])
3844
self.session_id = uuid.uuid4().hex
3945
self.msg_cnt = 0
40-
self.execute_requests: Dict[str, Dict[str, asyncio.Future]] = {}
41-
self.channel_tasks: List[asyncio.Task] = []
46+
self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {}
47+
self.comm_messages: asyncio.Queue = asyncio.Queue()
48+
self.tasks: List[asyncio.Task] = []
4249

4350
async def restart(self, startup_timeout: float = float("inf")) -> None:
44-
for task in self.channel_tasks:
51+
for task in self.tasks:
4552
task.cancel()
4653
msg = create_message("shutdown_request", content={"restart": True})
4754
await send_message(msg, self.control_channel, self.key, change_date_to_str=True)
@@ -52,7 +59,7 @@ async def restart(self, startup_timeout: float = float("inf")) -> None:
5259
if msg["msg_type"] == "shutdown_reply" and msg["content"]["restart"]:
5360
break
5461
await self._wait_for_ready(startup_timeout)
55-
self.channel_tasks = []
62+
self.tasks = []
5663
self.listen_channels()
5764

5865
async def start(self, startup_timeout: float = float("inf"), connect: bool = True) -> None:
@@ -69,6 +76,7 @@ async def connect(self, startup_timeout: float = float("inf")) -> None:
6976
self.connect_channels()
7077
await self._wait_for_ready(startup_timeout)
7178
self.listen_channels()
79+
self.tasks.append(asyncio.create_task(self._handle_comms()))
7280

7381
def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
7482
connection_cfg = connection_cfg or self.connection_cfg
@@ -77,40 +85,43 @@ def connect_channels(self, connection_cfg: Optional[cfg_t] = None):
7785
self.iopub_channel = connect_channel("iopub", connection_cfg)
7886

7987
def listen_channels(self):
80-
self.channel_tasks.append(asyncio.create_task(self.listen_iopub()))
81-
self.channel_tasks.append(asyncio.create_task(self.listen_shell()))
88+
self.tasks.append(asyncio.create_task(self.listen_iopub()))
89+
self.tasks.append(asyncio.create_task(self.listen_shell()))
8290

8391
async def stop(self) -> None:
8492
self.kernel_process.kill()
8593
await self.kernel_process.wait()
8694
os.remove(self.connection_file_path)
87-
for task in self.channel_tasks:
95+
for task in self.tasks:
8896
task.cancel()
8997

9098
async def listen_iopub(self):
9199
while True:
92100
msg = await receive_message(self.iopub_channel, change_str_to_date=True)
93-
msg_id = msg["parent_header"].get("msg_id")
94-
if msg_id in self.execute_requests.keys():
95-
self.execute_requests[msg_id]["iopub_msg"].set_result(msg)
101+
parent_id = msg["parent_header"].get("msg_id")
102+
if msg["msg_type"] in ("comm_open", "comm_msg"):
103+
self.comm_messages.put_nowait(msg)
104+
elif parent_id in self.execute_requests.keys():
105+
self.execute_requests[parent_id]["iopub_msg"].put_nowait(msg)
96106

97107
async def listen_shell(self):
98108
while True:
99109
msg = await receive_message(self.shell_channel, change_str_to_date=True)
100110
msg_id = msg["parent_header"].get("msg_id")
101111
if msg_id in self.execute_requests.keys():
102-
self.execute_requests[msg_id]["shell_msg"].set_result(msg)
112+
self.execute_requests[msg_id]["shell_msg"].put_nowait(msg)
103113

104114
async def execute(
105115
self,
106-
cell: Dict[str, Any],
116+
ycell: Map,
107117
timeout: float = float("inf"),
108118
msg_id: str = "",
109119
wait_for_executed: bool = True,
110120
) -> None:
111-
if cell["cell_type"] != "code":
121+
if ycell["cell_type"] != "code":
112122
return
113-
content = {"code": cell["source"], "silent": False}
123+
ycell["execution_state"] = "busy"
124+
content = {"code": str(ycell["source"]), "silent": False}
114125
msg = create_message(
115126
"execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt)
116127
)
@@ -120,40 +131,68 @@ async def execute(
120131
msg_id = msg["header"]["msg_id"]
121132
self.msg_cnt += 1
122133
await send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
134+
self.execute_requests[msg_id] = {
135+
"iopub_msg": asyncio.Queue(),
136+
"shell_msg": asyncio.Queue(),
137+
}
123138
if wait_for_executed:
124139
deadline = time.time() + timeout
125-
self.execute_requests[msg_id] = {
126-
"iopub_msg": asyncio.Future(),
127-
"shell_msg": asyncio.Future(),
128-
}
129140
while True:
130141
try:
131-
await asyncio.wait_for(
132-
self.execute_requests[msg_id]["iopub_msg"],
142+
msg = await asyncio.wait_for(
143+
self.execute_requests[msg_id]["iopub_msg"].get(),
133144
deadline_to_timeout(deadline),
134145
)
135146
except asyncio.TimeoutError:
136147
error_message = f"Kernel didn't respond in {timeout} seconds"
137148
raise RuntimeError(error_message)
138-
msg = self.execute_requests[msg_id]["iopub_msg"].result()
139-
self._handle_outputs(cell["outputs"], msg)
149+
await self._handle_outputs(ycell["outputs"], msg)
140150
if (
141-
msg["header"]["msg_type"] == "status"
142-
and msg["content"]["execution_state"] == "idle"
151+
(msg["header"]["msg_type"] == "status"
152+
and msg["content"]["execution_state"] == "idle")
143153
):
144154
break
145-
self.execute_requests[msg_id]["iopub_msg"] = asyncio.Future()
146155
try:
147-
await asyncio.wait_for(
148-
self.execute_requests[msg_id]["shell_msg"],
156+
msg = await asyncio.wait_for(
157+
self.execute_requests[msg_id]["shell_msg"].get(),
149158
deadline_to_timeout(deadline),
150159
)
151160
except asyncio.TimeoutError:
152161
error_message = f"Kernel didn't respond in {timeout} seconds"
153162
raise RuntimeError(error_message)
154-
msg = self.execute_requests[msg_id]["shell_msg"].result()
155-
cell["execution_count"] = msg["content"]["execution_count"]
163+
with ycell.doc.transaction():
164+
ycell["execution_count"] = msg["content"]["execution_count"]
165+
ycell["execution_state"] = "idle"
156166
del self.execute_requests[msg_id]
167+
else:
168+
self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell)))
169+
170+
async def _handle_iopub(self, msg_id: str, ycell: Map) -> None:
171+
while True:
172+
msg = await self.execute_requests[msg_id]["iopub_msg"].get()
173+
await self._handle_outputs(ycell["outputs"], msg)
174+
if (
175+
(msg["header"]["msg_type"] == "status"
176+
and msg["content"]["execution_state"] == "idle")
177+
):
178+
msg = await self.execute_requests[msg_id]["shell_msg"].get()
179+
with ycell.doc.transaction():
180+
ycell["execution_count"] = msg["content"]["execution_count"]
181+
ycell["execution_state"] = "idle"
182+
183+
async def _handle_comms(self) -> None:
184+
if self.yjs is None:
185+
return
186+
187+
while True:
188+
msg = await self.comm_messages.get()
189+
msg_type = msg["header"]["msg_type"]
190+
if msg_type == "comm_open":
191+
comm_id = msg["content"]["comm_id"]
192+
comm = Comm(comm_id, self.shell_channel, self.session_id, self.key)
193+
self.yjs.widgets.comm_open(msg, comm) # type: ignore
194+
elif msg_type == "comm_msg":
195+
self.yjs.widgets.comm_msg(msg) # type: ignore
157196

158197
async def _wait_for_ready(self, timeout):
159198
deadline = time.time() + timeout
@@ -178,22 +217,51 @@ async def _wait_for_ready(self, timeout):
178217
break
179218
new_timeout = deadline_to_timeout(deadline)
180219

181-
def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
220+
async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]):
182221
msg_type = msg["header"]["msg_type"]
183222
content = msg["content"]
184223
if msg_type == "stream":
185-
if (not outputs) or (outputs[-1]["name"] != content["name"]):
186-
outputs.append({"name": content["name"], "output_type": msg_type, "text": []})
187-
outputs[-1]["text"].append(content["text"])
224+
with outputs.doc.transaction():
225+
# TODO: uncomment when changes are made in jupyter-ydoc
226+
if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore
227+
outputs.append(
228+
#Map(
229+
# {
230+
# "name": content["name"],
231+
# "output_type": msg_type,
232+
# "text": Array([content["text"]]),
233+
# }
234+
#)
235+
{
236+
"name": content["name"],
237+
"output_type": msg_type,
238+
"text": [content["text"]],
239+
}
240+
)
241+
else:
242+
#outputs[-1]["text"].append(content["text"]) # type: ignore
243+
last_output = outputs[-1]
244+
last_output["text"].append(content["text"]) # type: ignore
245+
outputs[-1] = last_output
188246
elif msg_type in ("display_data", "execute_result"):
189-
outputs.append(
190-
{
191-
"data": {"text/plain": [content["data"].get("text/plain", "")]},
192-
"execution_count": content["execution_count"],
193-
"metadata": {},
194-
"output_type": msg_type,
195-
}
196-
)
247+
if "application/vnd.jupyter.ywidget-view+json" in content["data"]:
248+
# this is a collaborative widget
249+
model_id = content["data"]["application/vnd.jupyter.ywidget-view+json"]["model_id"]
250+
if self.yjs is not None:
251+
if model_id in self.yjs.widgets.widgets: # type: ignore
252+
doc = self.yjs.widgets.widgets[model_id]["model"].ydoc # type: ignore
253+
path = f"ywidget:{doc.guid}"
254+
await self.yjs.room_manager.websocket_server.get_room(path, ydoc=doc) # type: ignore
255+
outputs.append(doc)
256+
else:
257+
outputs.append(
258+
{
259+
"data": {"text/plain": [content["data"].get("text/plain", "")]},
260+
"execution_count": content["execution_count"],
261+
"metadata": {},
262+
"output_type": msg_type,
263+
}
264+
)
197265
elif msg_type == "error":
198266
outputs.append(
199267
{
@@ -203,5 +271,25 @@ def _handle_outputs(self, outputs: List[Dict[str, Any]], msg: Dict[str, Any]):
203271
"traceback": content["traceback"],
204272
}
205273
)
206-
else:
207-
return
274+
275+
276+
class Comm:
277+
def __init__(self, comm_id: str, shell_channel, session_id: str, key: str):
278+
self.comm_id = comm_id
279+
self.shell_channel = shell_channel
280+
self.session_id = session_id
281+
self.key = key
282+
self.msg_cnt = 0
283+
284+
def send(self, buffers):
285+
msg = create_message(
286+
"comm_msg",
287+
content={"comm_id": self.comm_id},
288+
session_id=self.session_id,
289+
msg_id=self.msg_cnt,
290+
buffers=buffers,
291+
)
292+
self.msg_cnt += 1
293+
asyncio.create_task(
294+
send_message(msg, self.shell_channel, self.key, change_date_to_str=True)
295+
)

plugins/kernels/fps_kernels/kernel_driver/message.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def create_message(
5656
content: Dict = {},
5757
session_id: str = "",
5858
msg_id: str = "",
59+
buffers: List = [],
5960
) -> Dict[str, Any]:
6061
header = create_message_header(msg_type, session_id, msg_id)
6162
msg = {
@@ -65,7 +66,7 @@ def create_message(
6566
"parent_header": {},
6667
"content": content,
6768
"metadata": {},
68-
"buffers": [],
69+
"buffers": buffers,
6970
}
7071
return msg
7172

plugins/kernels/fps_kernels/routes.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,21 +259,25 @@ async def execute_cell(
259259
execution = Execution(**r)
260260
if kernel_id in kernels:
261261
ynotebook = self.yjs.get_document(execution.document_id)
262-
cell = ynotebook.get_cell(execution.cell_idx)
263-
cell["outputs"] = []
262+
ycells = [ycell for ycell in ynotebook.ycells if ycell["id"] == execution.cell_id]
263+
if not ycells:
264+
return # FIXME
265+
266+
ycell = ycells[0]
267+
del ycell["outputs"][:]
264268

265269
kernel = kernels[kernel_id]
266270
if not kernel["driver"]:
267271
kernel["driver"] = driver = KernelDriver(
268272
kernelspec_path=Path(find_kernelspec(kernel["name"])).as_posix(),
269273
write_connection_file=False,
270274
connection_file=kernel["server"].connection_file_path,
275+
yjs=self.yjs,
271276
)
272277
await driver.connect()
273278
driver = kernel["driver"]
274279

275-
await driver.execute(cell)
276-
ynotebook.set_cell(execution.cell_idx, cell)
280+
await driver.execute(ycell, wait_for_executed=False)
277281

278282
async def get_kernel(
279283
self,

0 commit comments

Comments
 (0)