33
44import copy
55from functools import partial
6- from typing import Any , Callable , Dict , Optional
6+ from typing import Any , Callable , Dict , List , Optional
77from uuid import uuid4
8+ from warnings import warn
89
910from pycrdt import Array , Awareness , Doc , Map , Text
1011
@@ -102,8 +103,11 @@ def get_cell(self, index: int) -> Dict[str, Any]:
102103 :return: A cell.
103104 :rtype: Dict[str, Any]
104105 """
106+ return self ._cell_to_py (self ._ycells [index ])
107+
108+ def _cell_to_py (self , ycell : Map ) -> Dict [str , Any ]:
105109 meta = self ._ymeta .to_py ()
106- cell = self . _ycells [ index ] .to_py ()
110+ cell = ycell .to_py ()
107111 cell .pop ("execution_state" , None )
108112 cast_all (cell , float , int ) # cells coming from Yjs have e.g. execution_count as float
109113 if "id" in cell and meta ["nbformat" ] == 4 and meta ["nbformat_minor" ] <= 4 :
@@ -234,7 +238,7 @@ def set(self, value: Dict) -> None:
234238 nb_without_cells = {key : value [key ] for key in value .keys () if key != "cells" }
235239 nb = copy .deepcopy (nb_without_cells )
236240 cast_all (nb , int , float ) # Yjs expects numbers to be floating numbers
237- cells = value ["cells" ] or [
241+ new_cells = value ["cells" ] or [
238242 {
239243 "cell_type" : "code" ,
240244 "execution_count" : None ,
@@ -245,26 +249,79 @@ def set(self, value: Dict) -> None:
245249 "id" : str (uuid4 ()),
246250 }
247251 ]
252+ old_ycells_by_id = {ycell ["id" ]: ycell for ycell in self ._ycells }
248253
249254 with self ._ydoc .transaction ():
250- # clear document
251- self ._ymeta .clear ()
252- self ._ycells .clear ()
255+ try :
256+ new_cell_list : List [tuple [Map , dict ]] = []
257+ retained_cells = set ()
258+
259+ # Determine cells to be retained
260+ for new_cell in new_cells :
261+ cell_id = new_cell .get ("id" )
262+ if cell_id and (old_ycell := old_ycells_by_id .get (cell_id )):
263+ old_cell = self ._cell_to_py (old_ycell )
264+ if old_cell == new_cell :
265+ new_cell_list .append ((old_ycell , old_cell ))
266+ retained_cells .add (cell_id )
267+ continue
268+ # New or changed cell
269+ new_cell_list .append ((self .create_ycell (new_cell ), new_cell ))
270+
271+ # First delete all non-retained cells
272+ if not retained_cells :
273+ # fast path if no cells were retained
274+ self ._ycells .clear ()
275+ else :
276+ index = 0
277+ for old_ycell in list (self ._ycells ):
278+ if old_ycell ["id" ] not in retained_cells :
279+ self ._ycells .pop (index )
280+ else :
281+ index += 1
282+
283+ # Now add new cells
284+ index = 0
285+ for new_ycell , new_cell in new_cell_list :
286+ if len (self ._ycells ) > index :
287+ # we need to compare against a python cell to avoid
288+ # an extra transaction on new cells which are not yet
289+ # integrated into the ydoc document.
290+ if self ._ycells [index ]["id" ] == new_cell .get ("id" ):
291+ # retained cell
292+ index += 1
293+ continue
294+ self ._ycells .insert (index , new_ycell )
295+ index += 1
296+
297+ except Exception as e :
298+ # Fallback to total overwrite, warn to allow debugging
299+ warn (f"All cells were reloaded due to an error in granular reload logic: { e } " )
300+ self ._ycells .clear ()
301+ self ._ycells .extend ([new_ycell for (new_ycell , _new_cell ) in new_cell_list ])
302+
253303 for key in [
254304 k for k in self ._ystate .keys () if k not in ("dirty" , "path" , "document_id" )
255305 ]:
256306 del self ._ystate [key ]
257307
258- # initialize document
259- self ._ycells .extend ([self .create_ycell (cell ) for cell in cells ])
260- self ._ymeta ["nbformat" ] = nb .get ("nbformat" , NBFORMAT_MAJOR_VERSION )
261- self ._ymeta ["nbformat_minor" ] = nb .get ("nbformat_minor" , NBFORMAT_MINOR_VERSION )
308+ nbformat_major = nb .get ("nbformat" , NBFORMAT_MAJOR_VERSION )
309+ nbformat_minor = nb .get ("nbformat_minor" , NBFORMAT_MINOR_VERSION )
310+
311+ if self ._ymeta .get ("nbformat" ) != nbformat_major :
312+ self ._ymeta ["nbformat" ] = nbformat_major
313+
314+ if self ._ymeta .get ("nbformat_minor" ) != nbformat_minor :
315+ self ._ymeta ["nbformat_minor" ] = nbformat_minor
262316
317+ old_y_metadata = self ._ymeta .get ("metadata" )
318+ old_metadata = old_y_metadata .to_py () if old_y_metadata else {}
263319 metadata = nb .get ("metadata" , {})
264- metadata .setdefault ("language_info" , {"name" : "" })
265- metadata .setdefault ("kernelspec" , {"name" : "" , "display_name" : "" })
266320
267- self ._ymeta ["metadata" ] = Map (metadata )
321+ if metadata != old_metadata :
322+ metadata .setdefault ("language_info" , {"name" : "" })
323+ metadata .setdefault ("kernelspec" , {"name" : "" , "display_name" : "" })
324+ self ._ymeta ["metadata" ] = Map (metadata )
268325
269326 def observe (self , callback : Callable [[str , Any ], None ]) -> None :
270327 """
0 commit comments