Skip to content

Commit 7d855b7

Browse files
Merge pull request #1301 from datajoint/claude/semantic-match
Implement Semantic Joins
2 parents c3fbf35 + 19cde1c commit 7d855b7

File tree

12 files changed

+1468
-96
lines changed

12 files changed

+1468
-96
lines changed

docs/src/design/semantic-matching-spec.md

Lines changed: 540 additions & 0 deletions
Large diffs are not rendered by default.

src/datajoint/condition.py

Lines changed: 55 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import decimal
66
import inspect
77
import json
8+
import logging
89
import re
910
import uuid
1011
from dataclasses import dataclass
@@ -14,6 +15,8 @@
1415

1516
from .errors import DataJointError
1617

18+
logger = logging.getLogger(__name__.split(".")[0])
19+
1720
JSON_PATTERN = re.compile(r"^(?P<attr>\w+)(\.(?P<path>[\w.*\[\]]+))?(:(?P<type>[\w(,\s)]+))?$")
1821

1922

@@ -95,39 +98,68 @@ def __init__(self, restriction):
9598
self.restriction = restriction
9699

97100

98-
def assert_join_compatibility(expr1, expr2):
101+
def assert_join_compatibility(expr1, expr2, semantic_check=True):
99102
"""
100-
Determine if expressions expr1 and expr2 are join-compatible. To be join-compatible,
101-
the matching attributes in the two expressions must be in the primary key of one or the
102-
other expression.
103-
Raises an exception if not compatible.
103+
Determine if expressions expr1 and expr2 are join-compatible.
104+
105+
With semantic_check=True (default):
106+
Raises an error if there are non-homologous namesakes (same name, different lineage).
107+
This prevents accidental joins on attributes that share names but represent
108+
different entities.
109+
110+
If the ~lineage table doesn't exist for either schema, a warning is issued
111+
and semantic checking is disabled (join proceeds as natural join).
112+
113+
With semantic_check=False:
114+
No lineage checking. All namesake attributes are matched (natural join behavior).
104115
105116
:param expr1: A QueryExpression object
106117
:param expr2: A QueryExpression object
118+
:param semantic_check: If True (default), use semantic matching and error on conflicts
107119
"""
108120
from .expression import QueryExpression, U
109121

110122
for rel in (expr1, expr2):
111123
if not isinstance(rel, (U, QueryExpression)):
112124
raise DataJointError("Object %r is not a QueryExpression and cannot be joined." % rel)
113-
if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible
114-
try:
115-
raise DataJointError(
116-
"Cannot join query expressions on dependent attribute `%s`"
117-
% next(r for r in set(expr1.heading.secondary_attributes).intersection(expr2.heading.secondary_attributes))
118-
)
119-
except StopIteration:
120-
pass # all ok
121-
122125

123-
def make_condition(query_expression, condition, columns):
126+
# dj.U is always compatible (it represents all possible lineages)
127+
if isinstance(expr1, U) or isinstance(expr2, U):
128+
return
129+
130+
if semantic_check:
131+
# Check if lineage tracking is available for both expressions
132+
if not expr1.heading.lineage_available or not expr2.heading.lineage_available:
133+
logger.warning(
134+
"Semantic check disabled: ~lineage table not found. "
135+
"To enable semantic matching, rebuild lineage with: "
136+
"schema.rebuild_lineage()"
137+
)
138+
return
139+
140+
# Error on non-homologous namesakes
141+
namesakes = set(expr1.heading.names) & set(expr2.heading.names)
142+
for name in namesakes:
143+
lineage1 = expr1.heading[name].lineage
144+
lineage2 = expr2.heading[name].lineage
145+
# Semantic match requires both lineages to be non-None and equal
146+
if lineage1 is None or lineage2 is None or lineage1 != lineage2:
147+
raise DataJointError(
148+
f"Cannot join on attribute `{name}`: "
149+
f"different lineages ({lineage1} vs {lineage2}). "
150+
f"Use .proj() to rename one of the attributes."
151+
)
152+
153+
154+
def make_condition(query_expression, condition, columns, semantic_check=True):
124155
"""
125156
Translate the input condition into the equivalent SQL condition (a string)
126157
127158
:param query_expression: a dj.QueryExpression object to apply condition
128159
:param condition: any valid restriction object.
129160
:param columns: a set passed by reference to collect all column names used in the
130161
condition.
162+
:param semantic_check: If True (default), use semantic matching and error on conflicts.
131163
:return: an SQL condition string or a boolean value.
132164
"""
133165
from .expression import Aggregation, QueryExpression, U
@@ -180,7 +212,11 @@ def combine_conditions(negate, conditions):
180212
# restrict by AndList
181213
if isinstance(condition, AndList):
182214
# omit all conditions that evaluate to True
183-
items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition) if item is not True]
215+
items = [
216+
item
217+
for item in (make_condition(query_expression, cond, columns, semantic_check) for cond in condition)
218+
if item is not True
219+
]
184220
if any(item is False for item in items):
185221
return negate # if any item is False, the whole thing is False
186222
if not items:
@@ -226,14 +262,9 @@ def combine_conditions(negate, conditions):
226262
condition = condition()
227263

228264
# restrict by another expression (aka semijoin and antijoin)
229-
check_compatibility = True
230-
if isinstance(condition, PromiscuousOperand):
231-
condition = condition.operand
232-
check_compatibility = False
233-
234265
if isinstance(condition, QueryExpression):
235-
if check_compatibility:
236-
assert_join_compatibility(query_expression, condition)
266+
assert_join_compatibility(query_expression, condition, semantic_check=semantic_check)
267+
# Always match on all namesakes (natural join semantics)
237268
common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names]
238269
columns.update(common_attributes)
239270
if isinstance(condition, Aggregation):
@@ -255,7 +286,7 @@ def combine_conditions(negate, conditions):
255286

256287
# if iterable (but not a string, a QueryExpression, or an AndList), treat as an OrList
257288
try:
258-
or_list = [make_condition(query_expression, q, columns) for q in condition]
289+
or_list = [make_condition(query_expression, q, columns, semantic_check) for q in condition]
259290
except TypeError:
260291
raise DataJointError("Invalid restriction type %r" % condition)
261292
else:

src/datajoint/declare.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -141,7 +141,7 @@ def is_foreign_key(line):
141141
return arrow_position >= 0 and not any(c in line[:arrow_position] for c in "\"#'")
142142

143143

144-
def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql):
144+
def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql, fk_attribute_map=None):
145145
"""
146146
:param line: a line from a table definition
147147
:param context: namespace containing referenced objects
@@ -151,6 +151,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
151151
:param attr_sql: list of sql statements defining attributes -- to be updated by this function.
152152
:param foreign_key_sql: list of sql statements specifying foreign key constraints -- to be updated by this function.
153153
:param index_sql: list of INDEX declaration statements, duplicate or redundant indexes are ok.
154+
:param fk_attribute_map: dict mapping child attr -> (parent_table, parent_attr) -- to be updated by this function.
154155
"""
155156
# Parse and validate
156157
from .expression import QueryExpression
@@ -194,6 +195,11 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
194195
if primary_key is not None:
195196
primary_key.append(attr)
196197
attr_sql.append(ref.heading[attr].sql.replace("NOT NULL ", "", int(is_nullable)))
198+
# Track FK attribute mapping for lineage: child_attr -> (parent_table, parent_attr)
199+
if fk_attribute_map is not None:
200+
parent_table = ref.support[0] # e.g., `schema`.`table`
201+
parent_attr = ref.heading[attr].original_name
202+
fk_attribute_map[attr] = (parent_table, parent_attr)
197203

198204
# declare the foreign key
199205
foreign_key_sql.append(
@@ -223,6 +229,7 @@ def prepare_declare(definition, context):
223229
foreign_key_sql = []
224230
index_sql = []
225231
external_stores = []
232+
fk_attribute_map = {} # child_attr -> (parent_table, parent_attr)
226233

227234
for line in definition:
228235
if not line or line.startswith("#"): # ignore additional comments
@@ -238,6 +245,7 @@ def prepare_declare(definition, context):
238245
attribute_sql,
239246
foreign_key_sql,
240247
index_sql,
248+
fk_attribute_map,
241249
)
242250
elif re.match(r"^(unique\s+)?index\s*.*$", line, re.I): # index
243251
compile_index(line, index_sql)
@@ -258,6 +266,7 @@ def prepare_declare(definition, context):
258266
foreign_key_sql,
259267
index_sql,
260268
external_stores,
269+
fk_attribute_map,
261270
)
262271

263272

@@ -285,6 +294,7 @@ def declare(full_table_name, definition, context):
285294
foreign_key_sql,
286295
index_sql,
287296
external_stores,
297+
fk_attribute_map,
288298
) = prepare_declare(definition, context)
289299

290300
if config.get("add_hidden_timestamp", False):
@@ -297,11 +307,12 @@ def declare(full_table_name, definition, context):
297307
if not primary_key:
298308
raise DataJointError("Table must have a primary key")
299309

300-
return (
310+
sql = (
301311
"CREATE TABLE IF NOT EXISTS %s (\n" % full_table_name
302312
+ ",\n".join(attribute_sql + ["PRIMARY KEY (`" + "`,`".join(primary_key) + "`)"] + foreign_key_sql + index_sql)
303313
+ '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment
304-
), external_stores
314+
)
315+
return sql, external_stores, primary_key, fk_attribute_map
305316

306317

307318
def _make_attribute_alter(new, old, primary_key):
@@ -387,6 +398,7 @@ def alter(definition, old_definition, context):
387398
foreign_key_sql,
388399
index_sql,
389400
external_stores,
401+
_fk_attribute_map,
390402
) = prepare_declare(definition, context)
391403
(
392404
table_comment_,
@@ -395,6 +407,7 @@ def alter(definition, old_definition, context):
395407
foreign_key_sql_,
396408
index_sql_,
397409
external_stores_,
410+
_fk_attribute_map_,
398411
) = prepare_declare(old_definition, context)
399412

400413
# analyze differences between declarations

src/datajoint/expression.py

Lines changed: 55 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
from .condition import (
88
AndList,
99
Not,
10-
PromiscuousOperand,
1110
Top,
1211
assert_join_compatibility,
1312
extract_column_names,
@@ -152,13 +151,22 @@ def make_subquery(self):
152151
result._heading = self.heading.make_subquery_heading()
153152
return result
154153

155-
def restrict(self, restriction):
154+
def restrict(self, restriction, semantic_check=True):
156155
"""
157156
Produces a new expression with the new restriction applied.
158-
rel.restrict(restriction) is equivalent to rel & restriction.
159-
rel.restrict(Not(restriction)) is equivalent to rel - restriction
157+
158+
:param restriction: a sequence or an array (treated as OR list), another QueryExpression,
159+
an SQL condition string, or an AndList.
160+
:param semantic_check: If True (default), use semantic matching - only match on
161+
homologous namesakes and error on non-homologous namesakes.
162+
If False, use natural matching on all namesakes (no lineage checking).
163+
:return: A new QueryExpression with the restriction applied.
164+
165+
rel.restrict(restriction) is equivalent to rel & restriction.
166+
rel.restrict(Not(restriction)) is equivalent to rel - restriction
167+
160168
The primary key of the result is unaffected.
161-
Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b))
169+
Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b))
162170
Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists
163171
(logical disjunction of conditions)
164172
Inverse restriction is accomplished by either using the subtraction operator or the Not class.
@@ -185,17 +193,14 @@ def restrict(self, restriction):
185193
rel - None rel
186194
rel - any_empty_entity_set rel
187195
188-
When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least
196+
When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least
189197
one element in arg (hence arg is treated as an OrList).
190-
Conversely, rel - arg restricts rel to elements that do not match any elements in arg.
198+
Conversely, rel - arg restricts rel to elements that do not match any elements in arg.
191199
Two elements match when their common attributes have equal values or when they have no common attributes.
192200
All shared attributes must be in the primary key of either rel or arg or both or an error will be raised.
193201
194202
QueryExpression.restrict is the only access point that modifies restrictions. All other operators must
195203
ultimately call restrict()
196-
197-
:param restriction: a sequence or an array (treated as OR list), another QueryExpression, an SQL condition
198-
string, or an AndList.
199204
"""
200205
attributes = set()
201206
if isinstance(restriction, Top):
@@ -204,7 +209,7 @@ def restrict(self, restriction):
204209
) # make subquery to avoid overwriting existing Top
205210
result._top = restriction
206211
return result
207-
new_condition = make_condition(self, restriction, attributes)
212+
new_condition = make_condition(self, restriction, attributes, semantic_check=semantic_check)
208213
if new_condition is True:
209214
return self # restriction has no effect, return the same object
210215
# check that all attributes in condition are present in the query
@@ -240,14 +245,11 @@ def __and__(self, restriction):
240245
return self.restrict(restriction)
241246

242247
def __xor__(self, restriction):
243-
"""
244-
Permissive restriction operator ignoring compatibility check e.g. ``q1 ^ q2``.
245-
"""
246-
if inspect.isclass(restriction) and issubclass(restriction, QueryExpression):
247-
restriction = restriction()
248-
if isinstance(restriction, Not):
249-
return self.restrict(Not(PromiscuousOperand(restriction.restriction)))
250-
return self.restrict(PromiscuousOperand(restriction))
248+
"""The ^ operator has been removed in DataJoint 2.0."""
249+
raise DataJointError(
250+
"The ^ operator has been removed in DataJoint 2.0. "
251+
"Use .restrict(other, semantic_check=False) for restrictions without semantic checking."
252+
)
251253

252254
def __sub__(self, restriction):
253255
"""
@@ -274,30 +276,37 @@ def __mul__(self, other):
274276
return self.join(other)
275277

276278
def __matmul__(self, other):
277-
"""
278-
Permissive join of query expressions `self` and `other` ignoring compatibility check
279-
e.g. ``q1 @ q2``.
280-
"""
281-
if inspect.isclass(other) and issubclass(other, QueryExpression):
282-
other = other() # instantiate
283-
return self.join(other, semantic_check=False)
279+
"""The @ operator has been removed in DataJoint 2.0."""
280+
raise DataJointError(
281+
"The @ operator has been removed in DataJoint 2.0. "
282+
"Use .join(other, semantic_check=False) for joins without semantic checking."
283+
)
284284

285285
def join(self, other, semantic_check=True, left=False):
286286
"""
287-
create the joined QueryExpression.
288-
a * b is short for A.join(B)
289-
a @ b is short for A.join(B, semantic_check=False)
290-
Additionally, left=True will retain the rows of self, effectively performing a left join.
287+
Create the joined QueryExpression.
288+
289+
:param other: QueryExpression to join with
290+
:param semantic_check: If True (default), use semantic matching - only match on
291+
homologous namesakes (same lineage) and error on non-homologous namesakes.
292+
If False, use natural join on all namesakes (no lineage checking).
293+
:param left: If True, perform a left join (retain all rows from self)
294+
:return: The joined QueryExpression
295+
296+
a * b is short for a.join(b)
291297
"""
292-
# trigger subqueries if joining on renamed attributes
298+
# Joining with U is no longer supported
293299
if isinstance(other, U):
294-
return other * self
300+
raise DataJointError(
301+
"table * dj.U(...) is no longer supported in DataJoint 2.0. "
302+
"This pattern is no longer necessary with the new semantic matching system."
303+
)
295304
if inspect.isclass(other) and issubclass(other, QueryExpression):
296305
other = other() # instantiate
297306
if not isinstance(other, QueryExpression):
298307
raise DataJointError("The argument of join must be a QueryExpression")
299-
if semantic_check:
300-
assert_join_compatibility(self, other)
308+
assert_join_compatibility(self, other, semantic_check=semantic_check)
309+
# Always natural join on all namesakes
301310
join_attributes = set(n for n in self.heading.names if n in other.heading.names)
302311
# needs subquery if self's FROM clause has common attributes with other's FROM clause
303312
need_subquery1 = need_subquery2 = bool(
@@ -826,8 +835,18 @@ def join(self, other, left=False):
826835
return result
827836

828837
def __mul__(self, other):
829-
"""shorthand for join"""
830-
return self.join(other)
838+
"""The * operator with dj.U has been removed in DataJoint 2.0."""
839+
raise DataJointError(
840+
"dj.U(...) * table is no longer supported in DataJoint 2.0. "
841+
"This pattern is no longer necessary with the new semantic matching system."
842+
)
843+
844+
def __sub__(self, other):
845+
"""Anti-restriction with dj.U produces an infinite set."""
846+
raise DataJointError(
847+
"dj.U(...) - table produces an infinite set and is not supported. "
848+
"Consider using a different approach for your query."
849+
)
831850

832851
def aggr(self, group, **named_attributes):
833852
"""

0 commit comments

Comments
 (0)