Commit 49062c5

Adding the output processor.
1 parent f35aeec commit 49062c5

1 file changed: 233 additions, 0 deletions
@@ -0,0 +1,233 @@
import asyncio
import json

from pycrdt import Map

from traitlets import Dict, Unicode
from traitlets.config import LoggingConfigurable


class OutputProcessor(LoggingConfigurable):

    _cell_ids = Dict(default_value={})      # a map from msg_id -> cell_id
    _cell_indices = Dict(default_value={})  # a map from cell_id -> cell index in notebook
    _file_id = Unicode(default_value=None, allow_none=True)

    @property
    def settings(self):
        """A shortcut for the Tornado web app settings."""
        return self.parent.parent.webapp.settings

    @property
    def kernel_client(self):
        """A shortcut to the kernel client this output processor is attached to."""
        return self.parent

    @property
    def outputs_manager(self):
        """A shortcut for the OutputsManager instance."""
        return self.settings["outputs_manager"]

    @property
    def session_manager(self):
        """A shortcut for the kernel session manager."""
        return self.settings["session_manager"]

    @property
    def file_id_manager(self):
        """A shortcut for the file id manager."""
        return self.settings["file_id_manager"]

    @property
    def jupyter_server_ydoc(self):
        """A shortcut for the jupyter server ydoc manager."""
        return self.settings["jupyter_server_ydoc"]

    def clear(self, cell_id=None):
        """Clear the state of the output processor.

        This clears the state (saved msg_ids, cell_ids, cell indices) for the output
        processor. If cell_id is provided, only the state for that cell is cleared.
        """
        if cell_id is None:
            self._cell_ids = {}
            self._cell_indices = {}
        else:
            msg_id = self.get_msg_id(cell_id)
            if (msg_id is not None) and (msg_id in self._cell_ids):
                del self._cell_ids[msg_id]
            if cell_id in self._cell_indices:
                del self._cell_indices[cell_id]

    def set_cell_id(self, msg_id, cell_id):
        """Set the cell_id for a msg_id."""
        self._cell_ids[msg_id] = cell_id

    def get_cell_id(self, msg_id):
        """Retrieve a cell_id from a parent msg_id."""
        return self._cell_ids.get(msg_id)

    def get_msg_id(self, cell_id):
        """Retrieve a msg_id from a cell_id."""
        return {v: k for k, v in self._cell_ids.items()}.get(cell_id)

    # Incoming messages

    def process_incoming_message(self, channel: str, msg: list[bytes]):
        """Process incoming messages from the frontend.

        Save the cell_id <-> msg_id mapping.

        msg = [p_header, p_parent, p_metadata, p_content, buffer1, buffer2, ...]

        This method is used to create a map between cell_id and msg_id.
        Incoming execute_request messages have both a cell_id and msg_id.
        When output messages are sent back to the frontend, this map is used
        to find the cell_id for a given parent msg_id.
        """
        header = json.loads(msg[0])
        msg_id = header["msg_id"]
        msg_type = header["msg_type"]
        metadata = json.loads(msg[2])
        cell_id = metadata.get("cellId")
        if cell_id is None:
            return

        existing_msg_id = self.get_msg_id(cell_id)
        if existing_msg_id != msg_id:  # cell is being re-run, clear output state
            if self._file_id is not None:
                self.log.info(f"Cell has been rerun, removing old outputs: {self._file_id} {cell_id}")
                self.clear(cell_id)
                self.outputs_manager.clear(file_id=self._file_id, cell_id=cell_id)
        self.log.info(f"Saving (msg_id, cell_id): ({msg_id} {cell_id})")
        self.set_cell_id(msg_id, cell_id)
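
    # For example, an execute_request whose header carries msg_id "abc123" and whose
    # metadata carries cellId "cell-1" (values illustrative) results in
    # _cell_ids["abc123"] = "cell-1", so that IOPub replies with parent msg_id
    # "abc123" can later be routed back to cell "cell-1" by process_outgoing_message.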

    # Outgoing messages

    def process_outgoing_message(self, channel: str, msg: list[bytes]):
        """Process outgoing messages from the kernel."""
        dmsg = self.kernel_client.session.deserialize(msg)
        msg_type = dmsg["header"]["msg_type"]
        msg_id = dmsg["parent_header"]["msg_id"]
        content = dmsg["content"]
        cell_id = self.get_cell_id(msg_id)
        if cell_id is None:
            return
        asyncio.create_task(self.output_task(msg_type, cell_id, content))
        return None  # Don't allow the original message to propagate to the frontend

    async def output_task(self, msg_type, cell_id, content):
        """A coroutine to handle output messages."""
        try:
            kernel_session = await self.session_manager.get_session(kernel_id=self.kernel_id)
        except Exception:  # what exception to catch?
            return
        else:
            path = kernel_session["path"]

        file_id = self.file_id_manager.get_id(path)
        if file_id is None:
            return
        self._file_id = file_id
        try:
            notebook = await self.jupyter_server_ydoc.get_document(
                path=path,
                copy=False,
                file_format='json',
                content_type='notebook'
            )
        except Exception:  # what exception to catch?
            return
        cells = notebook.ycells

        cell_index, target_cell = self.find_cell(cell_id, cells)
        if target_cell is None:
            return

        # Convert from the message spec to the nbformat output structure
        output = self.transform_output(msg_type, content, ydoc=False)
        output_url = self.outputs_manager.write(file_id, cell_id, output)
        nb_output = Map({
            "output_type": "display_data",
            "data": {
                'text/html': f'<a href="{output_url}">Output</a>'
            },
            "metadata": {
                "outputs_service": True
            }
        })
        target_cell["outputs"].append(nb_output)
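
    # The full output is written to the OutputsManager; the shared notebook cell
    # only receives a small display_data placeholder whose "text/html" link points
    # at the stored output, with "outputs_service": True marking it as a placeholder.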

    def find_cell(self, cell_id, cells):
        """Find a cell with a given cell_id in the list of cells.

        This uses caching if we have seen the cell previously.
        """
        # Find the target_cell and its cell_index and cache
        target_cell = None
        cell_index = None
        try:
            # See if we have a cached value for the cell_index
            cell_index = self._cell_indices[cell_id]
            target_cell = cells[cell_index]
        except (KeyError, IndexError):
            # Do a linear scan to find the cell (no cached index, or a stale index)
            self.log.info(f"Linear scan: {cell_id}")
            cell_index, target_cell = self.scan_cells(cell_id, cells)
        else:
            # Verify that the cached value still matches
            if target_cell["id"] != cell_id:
                self.log.info(f"Invalid cache hit: {cell_id}")
                cell_index, target_cell = self.scan_cells(cell_id, cells)
            else:
                self.log.info(f"Validated cache hit: {cell_id}")
        return cell_index, target_cell

    def scan_cells(self, cell_id, cells):
        """Find the cell with a given cell_id in the list of cells.

        This does a simple linear scan of the cells, but in reverse order because
        we believe that users are more often running cells towards the end of the
        notebook.
        """
        target_cell = None
        cell_index = None
        for i in reversed(range(len(cells))):
            cell = cells[i]
            if cell["id"] == cell_id:
                target_cell = cell
                cell_index = i
                self._cell_indices[cell_id] = cell_index
                break
        return cell_index, target_cell

    def transform_output(self, msg_type, content, ydoc=False):
        """Transform output from IOPub messages to the nbformat specification."""
        if ydoc:
            factory = Map
        else:
            factory = lambda x: x
        if msg_type == "stream":
            output = factory({
                "output_type": "stream",
                "text": content["text"],
                "name": content["name"]
            })
        elif msg_type == "display_data":
            output = factory({
                "output_type": "display_data",
                "data": content["data"],
                "metadata": content["metadata"]
            })
        elif msg_type == "execute_result":
            output = factory({
                "output_type": "execute_result",
                "data": content["data"],
                "metadata": content["metadata"],
                "execution_count": content["execution_count"]
            })
        elif msg_type == "error":
            output = factory({
                "output_type": "error",
                "traceback": content["traceback"],
                "ename": content["ename"],
                "evalue": content["evalue"]
            })
        else:
            output = None  # unhandled message types produce no output
        return output
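
A minimal sketch of exercising transform_output on its own, assuming the file above is importable as output_processor (module name assumed for illustration); it maps an IOPub execute_result payload onto the nbformat output structure:

# Sketch only: "output_processor" is an assumed module name; the class needs no
# constructor arguments because it is a plain LoggingConfigurable.
from output_processor import OutputProcessor

op = OutputProcessor()
content = {
    "data": {"text/plain": "42"},
    "metadata": {},
    "execution_count": 1,
}
output = op.transform_output("execute_result", content, ydoc=False)
# output == {"output_type": "execute_result", "data": {"text/plain": "42"},
#            "metadata": {}, "execution_count": 1}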
