Skip to content

Commit 91d2742

Browse files
Fix cell modifications causing cell output reload and shift to active cell index (#360)
* Add a failing test case for granular modifications Update tests to expect specific events * Implement granular cell updates on Python side Fix * Move code to a private method * Add type guard, fix text removal * Apply suggestions from code review Co-authored-by: David Brochart <[email protected]> * Increase test coverage --------- Co-authored-by: David Brochart <[email protected]>
1 parent dd1fb84 commit 91d2742

File tree

2 files changed

+128
-3
lines changed

2 files changed

+128
-3
lines changed

jupyter_ydoc/ynotebook.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# The default minor version of the notebook format.
1818
NBFORMAT_MINOR_VERSION = 5
1919

20+
_CELL_KEY_TYPE_MAP = {"metadata": Map, "source": Text, "outputs": Array}
21+
2022

2123
class YNotebook(YBaseDoc):
2224
"""
@@ -249,7 +251,7 @@ def set(self, value: dict) -> None:
249251
"id": str(uuid4()),
250252
}
251253
]
252-
old_ycells_by_id = {ycell["id"]: ycell for ycell in self._ycells}
254+
old_ycells_by_id: dict[str, Map] = {ycell["id"]: ycell for ycell in self._ycells}
253255

254256
with self._ydoc.transaction():
255257
new_cell_list: list[dict] = []
@@ -260,7 +262,11 @@ def set(self, value: dict) -> None:
260262
cell_id = new_cell.get("id")
261263
if cell_id and (old_ycell := old_ycells_by_id.get(cell_id)):
262264
old_cell = self._cell_to_py(old_ycell)
263-
if old_cell == new_cell:
265+
updated_granularly = self._update_cell(
266+
old_cell=old_cell, new_cell=new_cell, old_ycell=old_ycell
267+
)
268+
269+
if updated_granularly:
264270
new_cell_list.append(old_cell)
265271
retained_cells.add(cell_id)
266272
continue
@@ -324,3 +330,57 @@ def observe(self, callback: Callable[[str, Any], None]) -> None:
324330
self._subscriptions[self._ystate] = self._ystate.observe(partial(callback, "state"))
325331
self._subscriptions[self._ymeta] = self._ymeta.observe_deep(partial(callback, "meta"))
326332
self._subscriptions[self._ycells] = self._ycells.observe_deep(partial(callback, "cells"))
333+
334+
def _update_cell(self, old_cell: dict, new_cell: dict, old_ycell: Map) -> bool:
335+
if old_cell == new_cell:
336+
return True
337+
# attempt to update cell granularly
338+
old_keys = set(old_cell.keys())
339+
new_keys = set(new_cell.keys())
340+
341+
shared_keys = old_keys & new_keys
342+
removed_keys = old_keys - new_keys
343+
added_keys = new_keys - old_keys
344+
345+
for key in shared_keys:
346+
if old_cell[key] != new_cell[key]:
347+
value = new_cell[key]
348+
if key == "output" and value:
349+
# outputs require complex handling - some have Text type nested;
350+
# for now skip creating them; clearing all outputs is fine
351+
return False
352+
353+
if key in _CELL_KEY_TYPE_MAP:
354+
kind = _CELL_KEY_TYPE_MAP[key]
355+
356+
if not isinstance(old_ycell[key], kind):
357+
# if our assumptions about types do not hold, fall back to hard update
358+
return False
359+
360+
if kind == Text:
361+
old: Text = old_ycell[key]
362+
old.clear()
363+
old += value
364+
elif kind == Array:
365+
old: Array = old_ycell[key]
366+
old.clear()
367+
old.extend(value)
368+
elif kind == Map:
369+
old: Map = old_ycell[key]
370+
old.clear()
371+
old.update(value)
372+
else:
373+
old_ycell[key] = new_cell[key]
374+
375+
for key in removed_keys:
376+
del old_ycell[key]
377+
378+
for key in added_keys:
379+
if key in _CELL_KEY_TYPE_MAP:
380+
# we hard-reload cells when keys that require nested types get added
381+
# to allow the frontend to connect observers; this could be changed
382+
# in the future, once frontends learn how to observe all changes
383+
return False
384+
else:
385+
old_ycell[key] = new_cell[key]
386+
return True

tests/test_ynotebook.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Copyright (c) Jupyter Development Team.
22
# Distributed under the terms of the Modified BSD License.
33

4-
from pycrdt import Map
4+
5+
from pycrdt import ArrayEvent, Map, MapEvent, TextEvent
6+
from pytest import mark
57

68
from jupyter_ydoc import YNotebook
79

@@ -114,3 +116,66 @@ def record_changes(topic, event):
114116
{"delete": 1},
115117
{"insert": [AnyInstanceOf(Map)]},
116118
]
119+
120+
121+
@mark.parametrize(
122+
"modifications, expected_events",
123+
[
124+
# modifications of single attributes
125+
([["source", "'b'"]], {TextEvent}),
126+
([["outputs", []]], {ArrayEvent}),
127+
([["execution_count", 2]], {MapEvent}),
128+
([["metadata", {"tags": []}]], {MapEvent}),
129+
([["new_key", "test"]], {MapEvent}),
130+
# multi-attribute modifications
131+
([["source", "10"], ["execution_count", 10]], {TextEvent, MapEvent}),
132+
],
133+
)
134+
def test_modify_single_cell(modifications, expected_events):
135+
nb = YNotebook()
136+
nb.set(
137+
{
138+
"cells": [
139+
{
140+
"id": "8800f7d8-6cad-42ef-a339-a9c185ffdd54",
141+
"cell_type": "code",
142+
"source": "'a'",
143+
"metadata": {"tags": ["test-tag"]},
144+
"outputs": [{"name": "stdout", "output_type": "stream", "text": ["a\n"]}],
145+
"execution_count": 1,
146+
},
147+
]
148+
}
149+
)
150+
151+
# Get the model as Python object
152+
model = nb.get()
153+
154+
# Make changes
155+
for modification in modifications:
156+
key, new_value = modification
157+
model["cells"][0][key] = new_value
158+
159+
changes = []
160+
161+
def record_changes(topic, event):
162+
changes.append((topic, event))
163+
164+
nb.observe(record_changes)
165+
nb.set(model)
166+
167+
for modification in modifications:
168+
key, new_value = modification
169+
after = nb.ycells[0][key]
170+
after_py = after.to_py() if hasattr(after, "to_py") else after
171+
assert after_py == new_value
172+
173+
# there should be only one change
174+
assert len(changes) == 1
175+
cell_events = [e for t, e in changes if t == "cells"]
176+
# and it should be a cell change
177+
assert len(cell_events) == 1
178+
# but it should be a change to cell data, not a change to the cell list
179+
events = cell_events[0]
180+
assert len(events) == len(expected_events)
181+
assert {type(e) for e in events} == expected_events

0 commit comments

Comments
 (0)