diff --git a/datajoint/__init__.py b/datajoint/__init__.py index a7c5e7b2f..d077b76da 100644 --- a/datajoint/__init__.py +++ b/datajoint/__init__.py @@ -54,25 +54,88 @@ "logger", "cli", ] +import importlib +from typing import TYPE_CHECKING -from . import errors -from .admin import kill, set_password -from .attribute_adapter import AttributeAdapter -from .blob import MatCell, MatStruct -from .cli import cli -from .connection import Connection, conn -from .diagram import Diagram +from . import errors from .errors import DataJointError -from .expression import AndList, Not, Top, U -from .fetch import key -from .hash import key_hash -from .logging import logger -from .schemas import Schema, VirtualModule, list_schemas -from .settings import config -from .table import FreeTable, Table -from .user_tables import Computed, Imported, Lookup, Manual, Part -from .version import __version__ - -ERD = Di = Diagram # Aliases for Diagram -schema = Schema # Aliases for Schema -create_virtual_module = VirtualModule # Aliases for VirtualModule +from .logging import logger +from .settings import config +from .version import __version__ + +from .connection import Connection, conn + +if TYPE_CHECKING: + from .admin import kill, set_password + from .attribute_adapter import AttributeAdapter + from .blob import MatCell, MatStruct + from .cli import cli + from .diagram import Diagram + from .expression import AndList, Not, Top, U + from .fetch import key + from .hash import key_hash + from .schemas import Schema, VirtualModule, list_schemas + from .table import FreeTable, Table + from .user_tables import Computed, Imported, Lookup, Manual, Part + + +_LAZY: dict[str, tuple[str, str]] = { + # admin + "kill": ("datajoint.admin", "kill"), + "set_password": ("datajoint.admin", "set_password"), + + # core objects + "Schema": ("datajoint.schemas", "Schema"), + "VirtualModule": ("datajoint.schemas", "VirtualModule"), + "list_schemas": ("datajoint.schemas", "list_schemas"), + + # tables + "Table": ("datajoint.table", "Table"), + "FreeTable": ("datajoint.table", "FreeTable"), + "Manual": ("datajoint.user_tables", "Manual"), + "Lookup": ("datajoint.user_tables", "Lookup"), + "Imported": ("datajoint.user_tables", "Imported"), + "Computed": ("datajoint.user_tables", "Computed"), + "Part": ("datajoint.user_tables", "Part"), + + # diagram + "Diagram": ("datajoint.diagram", "Diagram"), + + # expressions + "Not": ("datajoint.expression", "Not"), + "AndList": ("datajoint.expression", "AndList"), + "Top": ("datajoint.expression", "Top"), + "U": ("datajoint.expression", "U"), + + # misc utilities + "MatCell": ("datajoint.blob", "MatCell"), + "MatStruct": ("datajoint.blob", "MatStruct"), + "AttributeAdapter": ("datajoint.attribute_adapter", "AttributeAdapter"), + "key": ("datajoint.fetch", "key"), + "key_hash": ("datajoint.hash", "key_hash"), + "cli": ("datajoint.cli", "cli"), +} +_ALIAS: dict[str, str] = { + "ERD": "Diagram", + "Di": "Diagram", + "schema": "Schema", + "create_virtual_module": "VirtualModule", +} + + +def __getattr__(name: str): + if name in _ALIAS: + target = _ALIAS[name] + value = getattr(importlib.import_module(_LAZY[target][0]), _LAZY[target][1]) + globals()[target] = value + globals()[name] = value + return value + + if name in _LAZY: + module_name, attr = _LAZY[name] + module = importlib.import_module(module_name) + value = getattr(module, attr) + globals()[name] = value # cache + return value + + raise AttributeError(f"module 'datajoint' has no attribute {name}") \ No newline at end of file diff --git a/datajoint/user_tables.py b/datajoint/user_tables.py index 9c2e79d34..4d708fe5e 100644 --- a/datajoint/user_tables.py +++ b/datajoint/user_tables.py @@ -219,16 +219,17 @@ def table_name(cls): else cls.master.table_name + "__" + from_camel_case(cls.__name__) ) - def delete(self, force=False): + def delete(self, force=False, **kwargs): """ - unless force is True, prohibits direct deletes from parts. + Unless force is True, prohibits direct deletes from parts. + Accepts any kwargs supported by Table.delete and forwards them to super().delete. """ if force: - super().delete(force_parts=True) - else: - raise DataJointError( - "Cannot delete from a Part directly. Delete from master instead" - ) + return super().delete(force_parts=True, **kwargs) + + raise DataJointError( + "Cannot delete from a Part directly. Delete from master instead" + ) def drop(self, force=False): """ diff --git a/tests/conftest.py b/tests/conftest.py index 88d55e32f..3b9e8b1ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,9 +6,10 @@ from typing import Dict, List import certifi -import minio import networkx as nx import pytest + +minio = pytest.importorskip("minio") import urllib3 from packaging import version diff --git a/tests/test_cascading_delete.py b/tests/test_cascading_delete.py index 71216fcb2..492327634 100644 --- a/tests/test_cascading_delete.py +++ b/tests/test_cascading_delete.py @@ -37,6 +37,16 @@ def test_stepwise_delete(schema_simp_pop): ), "failed to delete from the parent table following child table deletion" +def test_part_delete_forwards_kwargs(schema_simp_pop): + assert not dj.config["safemode"], "safemode must be off for testing" + assert L() and A() and B() and B.C(), "schema population failed" + + # Should accept and forward kwargs supported by Table.delete + B.C().delete(force=True, transaction=False) + + assert not B.C(), "failed to delete child table with forwarded kwargs" + + def test_delete_tree_restricted(schema_simp_pop): assert not dj.config["safemode"], "safemode must be off for testing" assert ( diff --git a/tests/test_import_performance.py b/tests/test_import_performance.py new file mode 100644 index 000000000..d5066b520 --- /dev/null +++ b/tests/test_import_performance.py @@ -0,0 +1,6 @@ +def test_import_does_not_eager_load_heavy_deps(): + import sys + import datajoint # noqa: F401 + + assert "datajoint.diagram" not in sys.modules + assert "pandas" not in sys.modules diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 2dbea672e..dd36483e5 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -630,8 +630,8 @@ def test_top_restriction_with_keywords(self, schema_simp_pop): ] assert key.fetch(as_dict=True) == [ {"id": 2, "key": 6}, - {"id": 2, "key": 5}, {"id": 1, "key": 5}, + {"id": 2, "key": 5}, {"id": 0, "key": 4}, {"id": 1, "key": 4}, {"id": 2, "key": 4},