Skip to content

Commit 2b65f40

Browse files
authored
Handle update_display_data kernel message in notebooks (#104)
* Added a method to get all outputs for a cell.
* Updated handler to accommodate all outputs
* Updated to sort outputs by last modified time
* Updated handler
* lint
* Updated to return jsonl response
* Sorting outputs by output index, simplified length check
* Moved placeholder to a separate function, ruff format.
* Updated method name
* WIP: handle update_display_data
* Fixed outputs index check
* Removed debug statements
* Moved output index logic to a separate method, added tests
* Fixed failing test
* Removed stale test
* fmt
* Fixed missing code after rebase
* Clean up after rebase
* Added output index tracker
* Updated tests
* Updates tests
* Remove unused imports
* Added a state to track display_ids, added tests
* Removed index tracker
* Renamed cell state vars
* Moved statement close to logic
1 parent 866b683 commit 2b65f40

File tree

5 files changed

+141
-53
lines changed

5 files changed

+141
-53
lines changed

jupyter_server_documents/kernels/kernel_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ async def handle_document_related_message(self, msg: t.List[bytes]) -> t.Optiona
273273
target_cell["execution_count"] = execution_count
274274
break
275275

276-
case "stream" | "display_data" | "execute_result" | "error":
276+
case "stream" | "display_data" | "execute_result" | "error" | "update_display_data":
277277
if cell_id:
278278
# Process specific output messages through an optional processor
279279
if self.output_processor and cell_id:

jupyter_server_documents/outputs/manager.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
class OutputsManager(LoggingConfigurable):
1515
_last_output_index = Dict(default_value={})
16+
_output_index_by_display_id = Dict(default_value={})
17+
_display_ids_by_cell_id = Dict(default_value={})
1618
_stream_count = Dict(default_value={})
1719

1820
outputs_path = Instance(PurePath, help="The local runtime dir")
@@ -33,6 +35,38 @@ def _build_path(self, file_id, cell_id=None, output_index=None):
3335
if output_index is not None:
3436
path = path / f"{output_index}.output"
3537
return path
38+
39+
def _compute_output_index(self, cell_id, display_id=None):
    """Return the output index to use for the next output of a cell.

    Without a ``display_id`` every call allocates the next sequential
    index for ``cell_id``, starting at 0. With a ``display_id`` an index
    is allocated the first time the id is seen and remembered, so later
    calls with the same ``display_id`` get the same index back and do not
    advance the cell's counter.

    Args:
        cell_id (str): The cell identifier.
        display_id (str, optional): A display identifier. Defaults to None.
            Any falsy value is treated as "no display id".

    Returns:
        int: The output index.
    """
    if display_id:
        # Record the display id under its cell so the per-display index
        # can be found again from the cell id.
        self._display_ids_by_cell_id.setdefault(cell_id, set()).add(display_id)
        # NOTE(review): this mapping is keyed by display_id alone, so the
        # same display_id arriving for a different cell reuses the index
        # allocated for the first cell -- confirm that is intended.
        known_index = self._output_index_by_display_id.get(display_id)
        if known_index is not None:
            return known_index

    # Allocate the next free slot for this cell.
    index = self._last_output_index.get(cell_id, -1) + 1
    self._last_output_index[cell_id] = index
    if display_id:
        self._output_index_by_display_id[display_id] = index
    return index
66+
67+
def get_output_index(self, display_id: str):
    """Look up the output index recorded for ``display_id``.

    Returns the stored index, or ``None`` when no index has been
    recorded for that display id.
    """
    index_by_display_id = self._output_index_by_display_id
    return index_by_display_id.get(display_id)
3670

3771
def get_output(self, file_id, cell_id, output_index):
3872
"""Get an output by file_id, cell_id, and output_index."""
@@ -77,23 +111,21 @@ def get_stream(self, file_id, cell_id):
77111
with open(path, "r", encoding="utf-8") as f:
78112
output = f.read()
79113
return output
80-
81-
def write(self, file_id, cell_id, output):
114+
115+
def write(self, file_id, cell_id, output, display_id=None):
82116
"""Write a new output for file_id and cell_id.
83117
84118
Returns a placeholder output (pycrdt.Map) or None if no placeholder
85119
output should be written to the ydoc.
86120
"""
87-
placeholder = self.write_output(file_id, cell_id, output)
121+
placeholder = self.write_output(file_id, cell_id, output, display_id)
88122
if output["output_type"] == "stream" and self.stream_limit is not None:
89123
placeholder = self.write_stream(file_id, cell_id, output, placeholder)
90124
return placeholder
91125

92-
def write_output(self, file_id, cell_id, output):
126+
def write_output(self, file_id, cell_id, output, display_id=None):
93127
self._ensure_path(file_id, cell_id)
94-
last_index = self._last_output_index.get(cell_id, -1)
95-
index = last_index + 1
96-
self._last_output_index[cell_id] = index
128+
index = self._compute_output_index(cell_id, display_id)
97129
path = self._build_path(file_id, cell_id, index)
98130
data = json.dumps(output, ensure_ascii=False)
99131
with open(path, "w", encoding="utf-8") as f:
@@ -134,13 +166,15 @@ def clear(self, file_id, cell_id=None):
134166
"""Clear the state of the manager."""
135167
if cell_id is None:
136168
self._stream_count = {}
137-
path = self._build_path(file_id)
138169
else:
139-
try:
140-
del self._stream_count[cell_id]
141-
except KeyError:
142-
pass
143-
path = self._build_path(file_id, cell_id)
170+
self._stream_count.pop(cell_id, None)
171+
self._last_output_index.pop(cell_id, None)
172+
173+
display_ids = self._display_ids_by_cell_id.get(cell_id, [])
174+
for display_id in display_ids:
175+
self._output_index_by_display_id.pop(display_id, None)
176+
177+
path = self._build_path(file_id, cell_id)
144178
try:
145179
shutil.rmtree(path)
146180
except FileNotFoundError:

jupyter_server_documents/outputs/output_processor.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import asyncio
2-
import json
32

43
from pycrdt import Map
54

6-
from traitlets import Dict, Unicode, Bool, Instance
5+
from traitlets import Unicode, Bool
76
from traitlets.config import LoggingConfigurable
87
from jupyter_server_documents.kernels.message_cache import KernelMessageCache
98

@@ -92,9 +91,13 @@ async def output_task(self, msg_type, cell_id, content):
9291
# TODO: The session manager may have multiple notebooks connected to the kernel
9392
# but currently get_session only returns the first. We need to fix this and
9493
# find the notebook with the right cell_id.
95-
kernel_session = await self.session_manager.get_session(kernel_id=self.parent.parent.kernel_id)
94+
kernel_session = await self.session_manager.get_session(
95+
kernel_id=self.parent.parent.kernel_id
96+
)
9697
except Exception as e:
97-
self.log.error(f"An exception occurred when processing output for cell {cell_id}")
98+
self.log.error(
99+
f"An exception occurred when processing output for cell {cell_id}"
100+
)
98101
self.log.exception(e)
99102
return
100103
else:
@@ -106,10 +109,11 @@ async def output_task(self, msg_type, cell_id, content):
106109
return
107110
self._file_id = file_id
108111

112+
display_id = content.get("transient", {}).get("display_id")
109113
# Convert from the message spec to the nbformat output structure
110-
if self.use_outputs_service:
114+
if self.use_outputs_service:
111115
output = self.transform_output(msg_type, content, ydoc=False)
112-
output = self.outputs_manager.write(file_id, cell_id, output)
116+
output = self.outputs_manager.write(file_id, cell_id, output, display_id)
113117
else:
114118
output = self.transform_output(msg_type, content, ydoc=True)
115119

@@ -125,8 +129,12 @@ async def output_task(self, msg_type, cell_id, content):
125129
# Write the outputs to the ydoc cell.
126130
_, target_cell = notebook.find_cell(cell_id)
127131
if target_cell is not None and output is not None:
128-
target_cell["outputs"].append(output)
129-
self.log.info(f"Write output to ydoc: {path} {cell_id} {output}")
132+
output_index = self.outputs_manager.get_output_index(display_id) if display_id else None
133+
if output_index is not None:
134+
target_cell["outputs"][output_index] = output
135+
else:
136+
target_cell["outputs"].append(output)
137+
self.log.info(f"Wrote output to ydoc: {path} {cell_id} {output}")
130138

131139

132140
def transform_output(self, msg_type, content, ydoc=False):
@@ -141,7 +149,7 @@ def transform_output(self, msg_type, content, ydoc=False):
141149
"text": content["text"],
142150
"name": content["name"]
143151
})
144-
elif msg_type == "display_data":
152+
elif msg_type == "display_data" or msg_type == "update_display_data":
145153
output = factory({
146154
"output_type": "display_data",
147155
"data": content["data"],

jupyter_server_documents/tests/test_output_processor.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,32 +22,6 @@ def test_instantiation():
2222
op = OutputProcessor()
2323
assert isinstance(op, OutputProcessor)
2424

25-
def test_incoming_message():
26-
"""Test incoming message processing."""
27-
with TemporaryDirectory() as td:
28-
op = TestOutputProcessor()
29-
om = OutputsManager()
30-
op.settings["outputs_manager"] = om
31-
op.outputs_path = Path(td) / "outputs"
32-
# Simulate the running of a cell
33-
cell_id = str(uuid4())
34-
msg_id, msg = create_incoming_message(cell_id)
35-
op.process_incoming_message('shell', msg)
36-
assert op.get_cell_id(msg_id) == cell_id
37-
assert op.get_msg_id(cell_id) == msg_id
38-
# Clear the cell_id
39-
op.clear(cell_id)
40-
assert op.get_cell_id(msg_id) is None
41-
assert op.get_msg_id(cell_id) is None
42-
# Simulate the running of a cell
43-
cell_id = str(uuid4())
44-
msg_id, msg = create_incoming_message(cell_id)
45-
op.process_incoming_message('shell', msg)
46-
assert op.get_cell_id(msg_id) == cell_id
47-
assert op.get_msg_id(cell_id) == msg_id
48-
# # Run it again without clearing to ensure it self clears
49-
cell_id = str(uuid4())
50-
msg_id, msg = create_incoming_message(cell_id)
51-
op.process_incoming_message('shell', msg)
52-
assert op.get_cell_id(msg_id) == cell_id
53-
assert op.get_msg_id(cell_id) == msg_id
25+
# TODO: Implement this
26+
def test_output_task():
    """Placeholder: no assertions for ``output_task`` have been written yet."""

jupyter_server_documents/tests/test_outputs_manager.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def test_stream():
4747
file_id = str(uuid4())
4848
cell_id = str(uuid4())
4949
for s in streams:
50-
op.write_stream(file_id, cell_id, s)
50+
placeholder = op.write_output(file_id, cell_id, s)
51+
op.write_stream(file_id, cell_id, s, placeholder)
5152
assert op.get_stream(file_id, cell_id) == text
5253

5354
def test_display_data():
@@ -91,3 +92,74 @@ def file_not_found():
9192
op.get_output('a','b',0)
9293
with pytest.raises(FileNotFoundError):
9394
op.get_stream('a','b')
95+
96+
97+
def test__compute_output_index_basic():
    """
    Indices for a cell without a display ID are handed out sequentially,
    starting at 0.
    """
    manager = OutputsManager()

    for expected_index in range(3):
        assert manager._compute_output_index('cell1') == expected_index
107+
108+
def test__compute_output_index_with_display_id():
    """
    A display ID keeps the index it was first given; a new display ID
    gets the next index.
    """
    manager = OutputsManager()

    first = manager._compute_output_index('cell1', 'display1')
    repeated = manager._compute_output_index('cell1', 'display1')
    other = manager._compute_output_index('cell1', 'display2')

    # First sighting of a display ID allocates index 0
    assert first == 0
    # The same display ID maps back to the same index
    assert repeated == 0
    # A different display ID is allocated a fresh index
    assert other == 1
122+
123+
124+
def test__compute_output_index_multiple_cells():
    """
    Each cell gets its own index sequence starting at 0
    """
    manager = OutputsManager()

    for cell_id in ('cell1', 'cell2'):
        for expected_index in (0, 1):
            assert manager._compute_output_index(cell_id) == expected_index
134+
135+
def test_display_id_index_retrieval():
    """
    get_output_index returns the index of a known display ID and None
    for an unknown one
    """
    manager = OutputsManager()
    manager._compute_output_index('cell1', 'display1')

    known = manager.get_output_index('display1')
    unknown = manager.get_output_index('non_existent_display')

    assert known == 0
    assert unknown is None
145+
146+
def test_display_ids():
    """
    Test tracking of display IDs for a cell and that clearing the cell
    drops the indices recorded for them
    """
    op = OutputsManager()

    # Allocate multiple display IDs for a cell
    op._compute_output_index('cell1', 'display1')
    op._compute_output_index('cell1', 'display2')

    # Verify display IDs are tracked
    assert 'cell1' in op._display_ids_by_cell_id
    assert set(op._display_ids_by_cell_id['cell1']) == {'display1', 'display2'}
    assert op.get_output_index('display1') == 0
    assert op.get_output_index('display2') == 1

    # Clear cell indices
    op.clear('file1', 'cell1')

    # _display_ids_by_cell_id is keyed by cell ID, so checking display IDs
    # against it can never fail. Check the display-ID -> index mapping,
    # which is the state clear() resets for the cell.
    assert 'display1' not in op._output_index_by_display_id
    assert 'display2' not in op._output_index_by_display_id
    assert op.get_output_index('display1') is None
    assert op.get_output_index('display2') is None
    assert 'cell1' not in op._last_output_index

0 commit comments

Comments
 (0)