23
23
24
24
if t .TYPE_CHECKING :
25
25
import jupyter_client
26
+ from nbformat import NotebookNode
26
27
27
28
try :
28
29
import jupyter_server_ydoc
@@ -74,6 +75,14 @@ async def _get_ycell(
74
75
ydoc : jupyter_server_ydoc .app .YDocExtension | None ,
75
76
metadata : dict | None ,
76
77
) -> y .Map | None :
78
+ """Get the cell from which the execution was triggered.
79
+
80
+ Args:
81
+ ydoc: The YDoc jupyter server extension
82
+ metadata: Execution context
83
+ Returns:
84
+ The cell
85
+ """
77
86
if ydoc is None :
78
87
msg = "jupyter-collaboration extension is not installed on the server. Outputs won't be written within the document." # noqa: E501
79
88
get_logger ().warning (msg )
@@ -89,7 +98,7 @@ async def _get_ycell(
89
98
get_logger ().debug (msg )
90
99
return None
91
100
92
- notebook : YNotebook = await ydoc .get_document (room_id = document_id , copy = False )
101
+ notebook : YNotebook | None = await ydoc .get_document (room_id = document_id , copy = False )
93
102
94
103
if notebook is None :
95
104
msg = f"Document with ID { document_id } not found."
@@ -118,7 +127,14 @@ async def _get_ycell(
118
127
return ycell
119
128
120
129
121
- def _output_hook (ycell , outputs , msg ) -> None :
130
+ def _output_hook (outputs : list [NotebookNode ], ycell : y .Map | None , msg : dict ) -> None :
131
+ """Callback on execution request when an output is emitted.
132
+
133
+ Args:
134
+ outputs: A list of previously emitted outputs
135
+ ycell: The cell being executed
136
+ msg: The output message
137
+ """
122
138
msg_type = msg ["header" ]["msg_type" ]
123
139
if msg_type in ("display_data" , "stream" , "execute_result" , "error" ):
124
140
# FIXME support for version
@@ -162,6 +178,13 @@ def _stdin_hook(kernel_id: str, request_id: str, pending_input: PendingInput, ms
162
178
"""Callback on stdin message.
163
179
164
180
It will register the pending input as temporary answer to the execution request.
181
+
182
+ Args:
183
+ kernel_id: The Kernel ID
184
+ request_id: The request ID that triggers the input request
185
+ pending_input: The pending input description.
186
+ This object will be mutated with useful information from ``msg``.
187
+ msg: The stdin msg
165
188
"""
166
189
get_logger ().debug (f"Execution request { kernel_id } received a input request." )
167
190
if PendingInput .request_id is not None :
@@ -184,6 +207,17 @@ async def _execute_snippet(
184
207
metadata : dict | None ,
185
208
stdin_hook : t .Callable [[dict ], None ] | None ,
186
209
) -> dict [str , t .Any ]:
210
+ """Snippet executor
211
+
212
+ Args:
213
+ client: Kernel client
214
+ ydoc: Jupyter server YDoc extension
215
+ snippet: The code snippet to execute
216
+ metadata: The code snippet metadata; e.g. to define the snippet context
217
+ stdin_hook: The stdin message callback
218
+ Returns:
219
+ The execution status and outputs.
220
+ """
187
221
ycell = None
188
222
if metadata is not None :
189
223
ycell = await _get_ycell (ydoc , metadata )
@@ -201,7 +235,7 @@ async def _execute_snippet(
201
235
client .execute_interactive (
202
236
snippet ,
203
237
# FIXME stream partial results
204
- output_hook = partial (_output_hook , ycell , outputs ),
238
+ output_hook = partial (_output_hook , outputs , ycell ),
205
239
stdin_hook = stdin_hook if client .allow_stdin else None ,
206
240
)
207
241
)
@@ -327,21 +361,27 @@ async def cancel(self, kernel_id: str, timeout: float | None = None) -> None:
327
361
328
362
Args:
329
363
kernel_id : Kernel identifier
364
+ timeout: Timeout to await for completion in seconds
365
+
366
+ Raises:
367
+ TimeoutError: if a task is not cancelled in time
330
368
"""
331
369
# FIXME connect this to kernel lifecycle
332
370
get_logger ().debug (f"Cancel execution for kernel { kernel_id } ." )
333
- worker = self .__workers .pop (kernel_id , None )
334
- if worker is not None :
335
- worker .cancel ()
336
- await asyncio .wait_for (worker , timeout = timeout )
337
-
338
- queue = self .__tasks .pop (kernel_id , None )
339
- if queue is not None :
340
- await asyncio .wait_for (queue .join (), timeout = timeout )
341
-
342
- client = self .__kernel_clients .pop (kernel_id , None )
343
- if client is not None :
344
- client .stop_channels ()
371
+ try :
372
+ worker = self .__workers .pop (kernel_id , None )
373
+ if worker is not None :
374
+ worker .cancel ()
375
+ await asyncio .wait_for (worker , timeout = timeout )
376
+ finally :
377
+ try :
378
+ queue = self .__tasks .pop (kernel_id , None )
379
+ if queue is not None :
380
+ await asyncio .wait_for (queue .join (), timeout = timeout )
381
+ finally :
382
+ client = self .__kernel_clients .pop (kernel_id , None )
383
+ if client is not None :
384
+ client .stop_channels ()
345
385
346
386
async def send_input (self , kernel_id : str , value : str ) -> None :
347
387
"""Send input ``value`` to the kernel ``kernel_id``.
@@ -431,6 +471,13 @@ def put(self, kernel_id: str, snippet: str, metadata: dict | None = None) -> str
431
471
return uid
432
472
433
473
def _get_client (self , kernel_id : str ) -> jupyter_client .asynchronous .client .AsyncKernelClient :
474
+ """Get the cached kernel client for ``kernel_id``.
475
+
476
+ Args:
477
+ kernel_id: The kernel ID
478
+ Returns:
479
+ The client for the given kernel.
480
+ """
434
481
if kernel_id not in self .__kernel_clients :
435
482
km = self .__manager .get_kernel (kernel_id )
436
483
self .__kernel_clients [kernel_id ] = km .client ()
0 commit comments