7
7
import traceback
8
8
import typing as t
9
9
import uuid
10
+ from collections import namedtuple
10
11
from functools import partial
11
12
from http import HTTPStatus
12
13
29
30
# optional dependencies
30
31
...
31
32
33
+ PendingInput = namedtuple ("PendingInput" , ["request_id" , "content" ])
34
+
32
35
33
36
class ExecutionStack :
34
37
"""Execution request stack.
@@ -39,7 +42,7 @@ class ExecutionStack:
39
42
"""
40
43
41
44
def __init__ (self ):
42
- self .__pending_inputs : dict [str , dict ] = {}
45
+ self .__pending_inputs : dict [str , PendingInput ] = {}
43
46
self .__tasks : dict [str , asyncio .Task ] = {}
44
47
45
48
def __del__ (self ):
@@ -78,7 +81,12 @@ def get(self, kernel_id: str, uid: str) -> t.Any:
78
81
raise ValueError (f"Request { uid } does not exists." )
79
82
80
83
if kernel_id in self .__pending_inputs :
81
- return self .__pending_inputs .pop (kernel_id )
84
+ get_logger ().info (f"Kernel '{ kernel_id } ' has a pending input." )
85
+ # Check the request id is the one matching the appearance of the input
86
+ # Otherwise another cell still looking for its results may capture the
87
+ # pending input
88
+ if uid == self .__pending_inputs [kernel_id ].request_id :
89
+ return self .__pending_inputs .pop (kernel_id ).content
82
90
83
91
if self .__tasks [uid ].done ():
84
92
task = self .__tasks .pop (uid )
@@ -102,11 +110,11 @@ def put(
102
110
uid = str (uuid .uuid4 ())
103
111
104
112
self .__tasks [uid ] = asyncio .create_task (
105
- _execute_task (uid , km , snippet , ycell , partial (self ._stdin_hook , km .kernel_id ))
113
+ _execute_task (uid , km , snippet , ycell , partial (self ._stdin_hook , km .kernel_id , uid ))
106
114
)
107
115
return uid
108
116
109
- def _stdin_hook (self , kernel_id : str , msg : dict ) -> None :
117
+ def _stdin_hook (self , kernel_id : str , request_id : str , msg : dict ) -> None :
110
118
"""Callback on stdin message.
111
119
112
120
It will register the pending input as temporary answer to the execution request.
@@ -119,10 +127,13 @@ def _stdin_hook(self, kernel_id: str, msg: dict) -> None:
119
127
120
128
header = msg ["header" ].copy ()
121
129
header ["date" ] = header ["date" ].isoformat ()
122
- self .__pending_inputs [kernel_id ] = {
123
- "parent_header" : header ,
124
- "input_request" : msg ["content" ],
125
- }
130
+ self .__pending_inputs [kernel_id ] = PendingInput (
131
+ request_id ,
132
+ {
133
+ "parent_header" : header ,
134
+ "input_request" : msg ["content" ],
135
+ },
136
+ )
126
137
127
138
128
139
async def _execute_task (
@@ -159,7 +170,6 @@ async def _execute_snippet(
159
170
ycell : y .Map | None ,
160
171
stdin_hook : t .Callable [[dict ], None ] | None ,
161
172
) -> dict [str , t .Any ]:
162
-
163
173
if ycell is not None :
164
174
# Reset cell
165
175
with ycell .doc .transaction ():
@@ -294,7 +304,7 @@ async def post(self, kernel_id: str) -> None:
294
304
status_code = HTTPStatus .INTERNAL_SERVER_ERROR , reason = msg
295
305
)
296
306
297
- notebook : YNotebook = await self ._ydoc .get_document (document_id = document_id , copy = False )
307
+ notebook : YNotebook = await self ._ydoc .get_document (room_id = document_id , copy = False )
298
308
299
309
if notebook is None :
300
310
msg = f"Document with ID { document_id } not found."
@@ -310,9 +320,7 @@ async def post(self, kernel_id: str) -> None:
310
320
raise tornado .web .HTTPError (status_code = HTTPStatus .NOT_FOUND , reason = msg ) # noqa: B904
311
321
else :
312
322
# Check if there is more than one cell
313
- try :
314
- next (ycells )
315
- except StopIteration :
323
+ if next (ycells , None ) is not None :
316
324
get_logger ().warning ("Multiple cells have the same ID '%s'." , cell_id )
317
325
318
326
if ycell ["cell_type" ] != "code" :
0 commit comments