Skip to content

Commit 3eb4dfd

Browse files
committed
Do not replace cells/metadata which did not change (Python side)
1 parent 9a4a5b5 commit 3eb4dfd

File tree

1 file changed

+70
-13
lines changed

1 file changed

+70
-13
lines changed

jupyter_ydoc/ynotebook.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import copy
55
from functools import partial
6-
from typing import Any, Callable, Dict, Optional
6+
from typing import Any, Callable, Dict, List, Optional
77
from uuid import uuid4
8+
from warnings import warn
89

910
from pycrdt import Array, Awareness, Doc, Map, Text
1011

@@ -102,8 +103,11 @@ def get_cell(self, index: int) -> Dict[str, Any]:
102103
:return: A cell.
103104
:rtype: Dict[str, Any]
104105
"""
106+
return self._cell_to_py(self._ycells[index])
107+
108+
def _cell_to_py(self, ycell: Map) -> Dict[str, Any]:
105109
meta = self._ymeta.to_py()
106-
cell = self._ycells[index].to_py()
110+
cell = ycell.to_py()
107111
cell.pop("execution_state", None)
108112
cast_all(cell, float, int) # cells coming from Yjs have e.g. execution_count as float
109113
if "id" in cell and meta["nbformat"] == 4 and meta["nbformat_minor"] <= 4:
@@ -234,7 +238,7 @@ def set(self, value: Dict) -> None:
234238
nb_without_cells = {key: value[key] for key in value.keys() if key != "cells"}
235239
nb = copy.deepcopy(nb_without_cells)
236240
cast_all(nb, int, float) # Yjs expects numbers to be floating numbers
237-
cells = value["cells"] or [
241+
new_cells = value["cells"] or [
238242
{
239243
"cell_type": "code",
240244
"execution_count": None,
@@ -245,26 +249,79 @@ def set(self, value: Dict) -> None:
245249
"id": str(uuid4()),
246250
}
247251
]
252+
old_ycells_by_id = {ycell["id"]: ycell for ycell in self._ycells}
248253

249254
with self._ydoc.transaction():
250-
# clear document
251-
self._ymeta.clear()
252-
self._ycells.clear()
255+
try:
256+
new_cell_list: List[tuple[Map, dict]] = []
257+
retained_cells = set()
258+
259+
# Determine cells to be retained
260+
for new_cell in new_cells:
261+
cell_id = new_cell.get("id")
262+
if cell_id and (old_ycell := old_ycells_by_id.get(cell_id)):
263+
old_cell = self._cell_to_py(old_ycell)
264+
if old_cell == new_cell:
265+
new_cell_list.append((old_ycell, old_cell))
266+
retained_cells.add(cell_id)
267+
continue
268+
# New or changed cell
269+
new_cell_list.append((self.create_ycell(new_cell), new_cell))
270+
271+
# First delete all non-retained cells
272+
if not retained_cells:
273+
# fast path if no cells were retained
274+
self._ycells.clear()
275+
else:
276+
index = 0
277+
for old_ycell in list(self._ycells):
278+
if old_ycell["id"] not in retained_cells:
279+
self._ycells.pop(index)
280+
else:
281+
index += 1
282+
283+
# Now add new cells
284+
index = 0
285+
for new_ycell, new_cell in new_cell_list:
286+
if len(self._ycells) > index:
287+
# we need to compare against a python cell to avoid
288+
# an extra transaction on new cells which are not yet
289+
# integrated into the ydoc document.
290+
if self._ycells[index]["id"] == new_cell.get("id"):
291+
# retained cell
292+
index += 1
293+
continue
294+
self._ycells.insert(index, new_ycell)
295+
index += 1
296+
297+
except Exception as e:
298+
# Fallback to total overwrite, warn to allow debugging
299+
warn(f"All cells were reloaded due to an error in granular reload logic: {e}")
300+
self._ycells.clear()
301+
self._ycells.extend([new_ycell for (new_ycell, _new_cell) in new_cell_list])
302+
253303
for key in [
254304
k for k in self._ystate.keys() if k not in ("dirty", "path", "document_id")
255305
]:
256306
del self._ystate[key]
257307

258-
# initialize document
259-
self._ycells.extend([self.create_ycell(cell) for cell in cells])
260-
self._ymeta["nbformat"] = nb.get("nbformat", NBFORMAT_MAJOR_VERSION)
261-
self._ymeta["nbformat_minor"] = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION)
308+
nbformat_major = nb.get("nbformat", NBFORMAT_MAJOR_VERSION)
309+
nbformat_minor = nb.get("nbformat_minor", NBFORMAT_MINOR_VERSION)
310+
311+
if self._ymeta.get("nbformat") != nbformat_major:
312+
self._ymeta["nbformat"] = nbformat_major
313+
314+
if self._ymeta.get("nbformat_minor") != nbformat_minor:
315+
self._ymeta["nbformat_minor"] = nbformat_minor
262316

317+
old_y_metadata = self._ymeta.get("metadata")
318+
old_metadata = old_y_metadata.to_py() if old_y_metadata else {}
263319
metadata = nb.get("metadata", {})
264-
metadata.setdefault("language_info", {"name": ""})
265-
metadata.setdefault("kernelspec", {"name": "", "display_name": ""})
266320

267-
self._ymeta["metadata"] = Map(metadata)
321+
if metadata != old_metadata:
322+
metadata.setdefault("language_info", {"name": ""})
323+
metadata.setdefault("kernelspec", {"name": "", "display_name": ""})
324+
self._ymeta["metadata"] = Map(metadata)
268325

269326
def observe(self, callback: Callable[[str, Any], None]) -> None:
270327
"""

0 commit comments

Comments
 (0)