diff --git a/src/datajoint/__init__.py b/src/datajoint/__init__.py
index e4077816c..ae8d308d2 100644
--- a/src/datajoint/__init__.py
+++ b/src/datajoint/__init__.py
@@ -60,18 +60,18 @@
     "ValidationResult",
 ]
 
+# =============================================================================
+# Eager imports — core functionality needed immediately
+# =============================================================================
 from . import errors
 from . import migrate
-from .admin import kill
 from .codecs import (
     Codec,
     get_codec,
     list_codecs,
 )
 from .blob import MatCell, MatStruct
-from .cli import cli
 from .connection import Connection, conn
-from .diagram import Diagram
 from .errors import DataJointError
 from .expression import AndList, Not, Top, U
 from .hash import key_hash
@@ -83,5 +83,58 @@
 from .user_tables import Computed, Imported, Lookup, Manual, Part
 from .version import __version__
 
-ERD = Di = Diagram  # Aliases for Diagram
-schema = Schema  # Aliases for Schema
+schema = Schema  # Alias for Schema
+
+# =============================================================================
+# Lazy imports — heavy dependencies loaded on first access
+# =============================================================================
+# These modules import heavy dependencies (networkx, matplotlib, click, pymysql)
+# that slow down `import datajoint`. They are loaded on demand through the
+# module-level __getattr__ hook (PEP 562).
+
+_lazy_modules = {
+    # Diagram imports networkx and matplotlib
+    "Diagram": (".diagram", "Diagram"),
+    "Di": (".diagram", "Diagram"),
+    "ERD": (".diagram", "Diagram"),
+    "diagram": (".diagram", None),  # Return the module itself
+    # kill imports pymysql via connection
+    # NOTE(review): .connection is still imported eagerly above; confirm that
+    # deferring .admin really avoids loading pymysql
+    "kill": (".admin", "kill"),
+    # cli imports click
+    "cli": (".cli", "cli"),
+}
+
+
+def __getattr__(name: str):
+    """
+    Resolve a lazily exported name on first access (PEP 562).
+
+    Python calls this hook only when ``name`` is not found in the package
+    namespace. The resolved object is stored in the namespace, so later
+    lookups of the same name bypass this function. Names missing from
+    ``_lazy_modules`` raise AttributeError, as for any unknown attribute.
+
+    NOTE(review): importing a submodule directly (e.g. ``import datajoint.cli``)
+    binds that submodule on the package, so this hook is then never consulted
+    for ``cli`` and ``dj.cli`` is the module rather than the function.
+    """
+    if name not in _lazy_modules:
+        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+    import importlib
+
+    module_path, attr_name = _lazy_modules[name]
+    module = importlib.import_module(module_path, __package__)
+    # If attr_name is None, return the module itself
+    attr = module if attr_name is None else getattr(module, attr_name)
+    # Cache in module __dict__ to avoid repeated __getattr__ calls
+    # and to override the submodule that importlib adds automatically
+    globals()[name] = attr
+    return attr
+
+
+def __dir__():
+    """Return the package's names plus the lazy exports, for dir() and tab completion."""
+    return sorted(set(globals()) | set(_lazy_modules))
diff --git a/tests/unit/test_lazy_imports.py b/tests/unit/test_lazy_imports.py
new file mode 100644
index 000000000..7c1dc4c9e
--- /dev/null
+++ b/tests/unit/test_lazy_imports.py
@@ -0,0 +1,124 @@
+"""
+Tests for lazy import behavior.
+
+These tests verify that heavy dependencies (networkx, matplotlib, click)
+are not loaded until their associated features are accessed.
+"""
+
+import sys
+from contextlib import contextmanager
+
+
+def _is_datajoint_module(name):
+    """Return True for ``datajoint`` itself and any of its submodules."""
+    return name == "datajoint" or name.startswith("datajoint.")
+
+
+@contextmanager
+def fresh_datajoint():
+    """
+    Yield a freshly imported ``datajoint`` package.
+
+    Previously loaded ``datajoint`` modules are removed from ``sys.modules`` for
+    the duration of the block and put back afterwards, so the rest of the test
+    session keeps using the module objects it already holds references to.
+    """
+    saved = {name: module for name, module in sys.modules.items() if _is_datajoint_module(name)}
+    for name in saved:
+        del sys.modules[name]
+    try:
+        import datajoint as dj
+
+        yield dj
+    finally:
+        # Drop whatever the fresh import loaded, then restore the originals
+        for name in [key for key in sys.modules if _is_datajoint_module(key)]:
+            del sys.modules[name]
+        sys.modules.update(saved)
+
+
+def test_lazy_diagram_import():
+    """Diagram module should not be loaded until dj.Diagram is accessed."""
+    with fresh_datajoint() as dj:
+        # Diagram module should not be loaded yet
+        assert "datajoint.diagram" not in sys.modules, "diagram module loaded eagerly"
+
+        # Access Diagram - should trigger lazy load
+        Diagram = dj.Diagram
+        assert "datajoint.diagram" in sys.modules, "diagram module not loaded after access"
+        assert Diagram.__name__ == "Diagram"
+
+
+def test_lazy_admin_import():
+    """Admin module should not be loaded until dj.kill is accessed."""
+    with fresh_datajoint() as dj:
+        # Admin module should not be loaded yet
+        assert "datajoint.admin" not in sys.modules, "admin module loaded eagerly"
+
+        # Access kill - should trigger lazy load
+        kill = dj.kill
+        assert "datajoint.admin" in sys.modules, "admin module not loaded after access"
+        assert callable(kill)
+
+
+def test_lazy_cli_import():
+    """CLI module should not be loaded until dj.cli is accessed."""
+    with fresh_datajoint() as dj:
+        # CLI module should not be loaded yet
+        assert "datajoint.cli" not in sys.modules, "cli module loaded eagerly"
+
+        # Access cli - should trigger lazy load and return the function
+        cli_func = dj.cli
+        assert "datajoint.cli" in sys.modules, "cli module not loaded after access"
+        assert callable(cli_func), "dj.cli should be callable (the cli function)"
+
+
+def test_diagram_module_access():
+    """dj.diagram should return the diagram module for accessing module-level attrs."""
+    with fresh_datajoint() as dj:
+        # Access dj.diagram should return the module
+        diagram_module = dj.diagram
+        assert hasattr(diagram_module, "diagram_active"), "diagram module should have diagram_active"
+        assert hasattr(diagram_module, "Diagram"), "diagram module should have Diagram class"
+
+
+def test_diagram_aliases():
+    """Di and ERD should be aliases for Diagram."""
+    with fresh_datajoint() as dj:
+        # All aliases should resolve to the same class
+        assert dj.Diagram is dj.Di
+        assert dj.Diagram is dj.ERD
+
+
+def test_lazy_names_listed_in_dir():
+    """Lazy names should appear in dir(dj) without triggering their imports."""
+    with fresh_datajoint() as dj:
+        listed = dir(dj)
+        for lazy_name in ("Diagram", "Di", "ERD", "diagram", "kill", "cli"):
+            assert lazy_name in listed, f"{lazy_name} missing from dir(datajoint)"
+
+        # Listing names must not import the heavy modules
+        assert "datajoint.diagram" not in sys.modules
+        assert "datajoint.admin" not in sys.modules
+        assert "datajoint.cli" not in sys.modules
+
+
+def test_core_imports_available():
+    """Core functionality should be available immediately after import."""
+    with fresh_datajoint() as dj:
+        # Core classes should be available without triggering lazy loads
+        assert hasattr(dj, "Schema")
+        assert hasattr(dj, "Table")
+        assert hasattr(dj, "Manual")
+        assert hasattr(dj, "Lookup")
+        assert hasattr(dj, "Computed")
+        assert hasattr(dj, "Imported")
+        assert hasattr(dj, "Part")
+        assert hasattr(dj, "Connection")
+        assert hasattr(dj, "config")
+        assert hasattr(dj, "errors")
+
+        # Heavy modules should still not be loaded
+        assert "datajoint.diagram" not in sys.modules
+        assert "datajoint.admin" not in sys.modules
+        assert "datajoint.cli" not in sys.modules