diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/stores.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/stores.py index 0e17e855..5e889b6b 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/stores.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/stores.py @@ -5,6 +5,8 @@ from pycrdt.store import TempFileYStore as _TempFileYStore from traitlets import Int, Unicode from traitlets.config import LoggingConfigurable +import importlib +from typing import Callable class TempFileYStore(_TempFileYStore): @@ -15,6 +17,30 @@ class SQLiteYStoreMetaclass(type(LoggingConfigurable), type(_SQLiteYStore)): # pass +def import_from_dotted_path(dotted_path: str) -> Callable | None: + """Import a function from a dotted import path. + + Args: + dotted_path: String like 'module.submodule.function_name' + + Returns: + The imported function + + Raises: + ImportError: If the module or function cannot be imported + AttributeError: If the function doesn't exist in the module + """ + if not dotted_path: + return None + + try: + module_path, function_name = dotted_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, function_name) + except (ValueError, ImportError, AttributeError) as e: + raise ImportError(f"Could not import '{dotted_path}': {e}") + + class SQLiteYStore(LoggingConfigurable, _SQLiteYStore, metaclass=SQLiteYStoreMetaclass): db_path = Unicode( ".jupyter_ystore.db", @@ -30,3 +56,53 @@ class SQLiteYStore(LoggingConfigurable, _SQLiteYStore, metaclass=SQLiteYStoreMet help="""The document time-to-live in seconds. Defaults to None (document history is never cleared).""", ) + + compress_function = Unicode( + "", + config=True, + help="""Dotted import path to compression function. The function should accept bytes + and return compressed bytes. Defaults to None (no compression).""", + ) + + decompress_function = Unicode( + "", + config=True, + help="""Dotted import path to decompression function. The function should accept + compressed bytes and return decompressed bytes. Defaults to None (no decompression).""", + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._setup_compression() + + def _setup_compression(self): + """Set up compression callbacks if both compress and decompress paths are provided.""" + if not self.compress_function or not self.decompress_function: + # If either is empty, don't set up compression + if self.compress_function or self.decompress_function: + self.log.warning( + "Both compress_function and decompress_function must be specified " + "to enable compression. Currently only one is set." + ) + return + + try: + # Import the compression functions + compress_func = import_from_dotted_path(self.compress_function) + decompress_func = import_from_dotted_path(self.decompress_function) + + # Validate that they are callable + if not callable(compress_func) or not callable(decompress_func): + raise TypeError("Both compression functions must be callable") + + # Register the compression callbacks + self.register_compression_callbacks(compress_func, decompress_func) + self.log.info( + f"Registered compression callbacks: {self.compress_function} / {self.decompress_function}" + ) + except ImportError as e: + self.log.error(f"Failed to import compression functions: {e}") + except TypeError as e: + self.log.error(f"Invalid compression functions: {e}") + except Exception as e: + self.log.error(f"Unexpected error setting up compression: {e}") diff --git a/tests/test_app.py b/tests/test_app.py index b46b2bb7..12e81402 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -77,6 +77,30 @@ async def test_document_ttl_from_settings(rtc_create_mock_document_room, jp_conf assert store.document_ttl == 3600 +async def test_compression_from_settings(rtc_create_mock_document_room, jp_configurable_serverapp): + argv = [ + "--SQLiteYStore.compress_function=gzip.compress", + "--SQLiteYStore.decompress_function=gzip.decompress", + ] + + app = jp_configurable_serverapp(argv=argv) + + id = "test-compression" + content = "test_compression_content" + rtc_create_SQLite_store = rtc_create_SQLite_store_factory(app) + store = await rtc_create_SQLite_store("file", id, content) + + assert store.compress_function == "gzip.compress" + assert store.decompress_function == "gzip.decompress" + + test_data = b"Hello, world! This is test data for compression." + compressed = store._compress(test_data) + decompressed = store._decompress(compressed) + + assert compressed != test_data + assert decompressed == test_data + + @pytest.mark.parametrize("copy", [True, False]) async def test_get_document_file(rtc_create_file, jp_serverapp, copy): path, content = await rtc_create_file("test.txt", "test", store=True)