Skip to content

Commit bc8e3bd

Browse files
fix: handle tier prefixes in Schema.get_table and __contains__
The get_table(), __getitem__, and __contains__ methods now auto-detect table tier prefixes (Manual: none, Lookup: #, Imported: _, Computed: __). This allows users to access tables by their base name without knowing the tier prefix: - schema.get_table("experiment") finds "_experiment" (Imported) - schema["Subject"] finds "#subject" (Lookup) - "Experiment" in schema returns True Added _find_table_name() helper that checks exact match first, then tries each tier prefix. Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 405f10e commit bc8e3bd

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

src/datajoint/schemas.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,33 @@ def list_tables(self) -> list[str]:
628628
if d == self.database
629629
]
630630

631+
def _find_table_name(self, name: str) -> str | None:
632+
"""
633+
Find the actual SQL table name for a given base name.
634+
635+
Handles tier prefixes: Manual (none), Lookup (#), Imported (_), Computed (__).
636+
637+
Parameters
638+
----------
639+
name : str
640+
Base table name without tier prefix.
641+
642+
Returns
643+
-------
644+
str or None
645+
The actual SQL table name, or None if not found.
646+
"""
647+
tables = self.list_tables()
648+
# Check exact match first
649+
if name in tables:
650+
return name
651+
# Check with tier prefixes
652+
for prefix in ("", "#", "_", "__"):
653+
candidate = f"{prefix}{name}"
654+
if candidate in tables:
655+
return candidate
656+
return None
657+
631658
def get_table(self, name: str) -> FreeTable:
632659
"""
633660
Get a table instance by name.
@@ -640,6 +667,7 @@ def get_table(self, name: str) -> FreeTable:
640667
name : str
641668
Table name (e.g., 'experiment', 'session__trial' for parts).
642669
Can be snake_case (SQL name) or CamelCase (class name).
670+
Tier prefixes are optional and will be auto-detected.
643671
644672
Returns
645673
-------
@@ -659,17 +687,15 @@ def get_table(self, name: str) -> FreeTable:
659687
"""
660688
self._assert_exists()
661689
# Convert CamelCase to snake_case if needed
662-
import re
663-
664690
if name[0].isupper():
665-
# CamelCase to snake_case conversion
666691
name = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
667692

668-
full_name = f"`{self.database}`.`{name}`"
669-
table = FreeTable(self.connection, full_name)
670-
if not table.is_declared:
693+
table_name = self._find_table_name(name)
694+
if table_name is None:
671695
raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.")
672-
return table
696+
697+
full_name = f"`{self.database}`.`{table_name}`"
698+
return FreeTable(self.connection, full_name)
673699

674700
def __getitem__(self, name: str) -> FreeTable:
675701
"""
@@ -721,6 +747,7 @@ def __contains__(self, name: str) -> bool:
721747
----------
722748
name : str
723749
Table name (snake_case or CamelCase).
750+
Tier prefixes are optional and will be auto-detected.
724751
725752
Returns
726753
-------
@@ -732,11 +759,9 @@ def __contains__(self, name: str) -> bool:
732759
>>> 'Experiment' in schema
733760
True
734761
"""
735-
import re
736-
737762
if name[0].isupper():
738763
name = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
739-
return name in self.list_tables()
764+
return self._find_table_name(name) is not None
740765

741766

742767
class VirtualModule(types.ModuleType):

0 commit comments

Comments
 (0)