Skip to content

Commit b7889d0

Browse files
fix: cache lazy imports correctly and expose diagram module
- Cache lazy imports in globals() to override the submodule that importlib automatically sets on the parent module - Add dj.diagram to lazy modules (returns module for diagram_active access) - Add tests for cli callable and diagram module access Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 2c51c5c commit b7889d0

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

src/datajoint/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
"Diagram": (".diagram", "Diagram"),
9797
"Di": (".diagram", "Diagram"),
9898
"ERD": (".diagram", "Diagram"),
99+
"diagram": (".diagram", None), # Return the module itself
99100
# kill imports pymysql via connection
100101
"kill": (".admin", "kill"),
101102
# cli imports click
@@ -110,5 +111,10 @@ def __getattr__(name: str):
110111
import importlib
111112

112113
module = importlib.import_module(module_path, __package__)
113-
return getattr(module, attr_name)
114+
# If attr_name is None, return the module itself
115+
attr = module if attr_name is None else getattr(module, attr_name)
116+
# Cache in module __dict__ to avoid repeated __getattr__ calls
117+
# and to override the submodule that importlib adds automatically
118+
globals()[name] = attr
119+
return attr
114120
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

tests/unit/test_lazy_imports.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,25 @@ def test_lazy_cli_import():
5959
# CLI module should not be loaded yet
6060
assert "datajoint.cli" not in sys.modules, "cli module loaded eagerly"
6161

62-
# Access cli - should trigger lazy load
63-
_ = dj.cli
62+
# Access cli - should trigger lazy load and return the function
63+
cli_func = dj.cli
6464
assert "datajoint.cli" in sys.modules, "cli module not loaded after access"
65+
assert callable(cli_func), "dj.cli should be callable (the cli function)"
66+
67+
68+
def test_diagram_module_access():
69+
"""dj.diagram should return the diagram module for accessing module-level attrs."""
70+
# Remove datajoint from sys.modules to get fresh import
71+
modules_to_remove = [key for key in sys.modules if key.startswith("datajoint")]
72+
for mod in modules_to_remove:
73+
del sys.modules[mod]
74+
75+
import datajoint as dj
76+
77+
# Access dj.diagram should return the module
78+
diagram_module = dj.diagram
79+
assert hasattr(diagram_module, "diagram_active"), "diagram module should have diagram_active"
80+
assert hasattr(diagram_module, "Diagram"), "diagram module should have Diagram class"
6581

6682

6783
def test_diagram_aliases():

0 commit comments

Comments
 (0)