Skip to content

Commit f71fe38

Browse files
Fix join with aggregations.
1 parent ed16dc8 commit f71fe38

File tree

4 files changed

+97
-83
lines changed

4 files changed

+97
-83
lines changed

datajoint/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,15 @@ def connect(self):
203203
self._conn = client.connect(
204204
init_command=self.init_fun,
205205
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
206-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
206+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
207207
charset=config['connection.charset'],
208208
**{k: v for k, v in self.conn_info.items()
209209
if k not in ['ssl_input', 'host_input']})
210210
except client.err.InternalError:
211211
self._conn = client.connect(
212212
init_command=self.init_fun,
213213
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
214-
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
214+
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
215215
charset=config['connection.charset'],
216216
**{k: v for k, v in self.conn_info.items()
217217
if not(k in ['ssl_input', 'host_input'] or

datajoint/expression.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,15 @@ def restriction_attributes(self):
8484
def primary_key(self):
8585
return self.heading.primary_key
8686

87-
_subquery_alias_count = count() # count for alias names used in from_clause
87+
_subquery_alias_count = count() # count for alias names used in the FROM clause
8888

8989
def from_clause(self):
90-
support = ('(' + src.make_sql() + ') as `_s%x`' % next(
91-
self._subquery_alias_count) if isinstance(src, QueryExpression) else src for src in self.support)
90+
support = ('(' + src.make_sql() + ') as `$%x`' % next(
91+
self._subquery_alias_count) if isinstance(src, QueryExpression)
92+
else src for src in self.support)
9293
clause = next(support)
9394
for s, left in zip(support, self._left):
94-
clause += 'NATURAL{left} JOIN {clause}'.format(
95+
clause += ' NATURAL{left} JOIN {clause}'.format(
9596
left=" LEFT" if left else "",
9697
clause=s)
9798
return clause
@@ -264,8 +265,10 @@ def join(self, other, semantic_check=True, left=False):
264265
(set(self.original_heading.names) & set(other.original_heading.names))
265266
- join_attributes)
266267
# need subquery if any of the join attributes are derived
267-
need_subquery1 = need_subquery1 or any(n in self.heading.new_attributes for n in join_attributes)
268-
need_subquery2 = need_subquery2 or any(n in other.heading.new_attributes for n in join_attributes)
268+
need_subquery1 = (need_subquery1 or isinstance(self, Aggregation) or
269+
any(n in self.heading.new_attributes for n in join_attributes))
270+
need_subquery2 = (need_subquery2 or isinstance(other, Aggregation) or
271+
any(n in other.heading.new_attributes for n in join_attributes))
269272
if need_subquery1:
270273
self = self.make_subquery()
271274
if need_subquery2:
@@ -721,8 +724,9 @@ def __and__(self, other):
721724

722725
def join(self, other, left=False):
723726
"""
724-
Joining U with a query expression has the effect of promoting the attributes of U to the primary key of
725-
the other query expression.
727+
Joining U with a query expression has the effect of promoting the attributes of U to
728+
the primary key of the other query expression.
729+
726730
:param other: the other query expression to join with.
727731
:param left: ignored. dj.U always acts as if left=False
728732
:return: a copy of the other query expression with the primary key extended.
@@ -733,12 +737,14 @@ def join(self, other, left=False):
733737
raise DataJointError('Set U can only be joined with a QueryExpression.')
734738
try:
735739
raise DataJointError(
736-
'Attribute `%s` not found' % next(k for k in self.primary_key if k not in other.heading.names))
740+
'Attribute `%s` not found' % next(k for k in self.primary_key
741+
if k not in other.heading.names))
737742
except StopIteration:
738743
pass # all ok
739744
result = copy.copy(other)
740745
result._heading = result.heading.set_primary_key(
741-
other.primary_key + [k for k in self.primary_key if k not in other.primary_key])
746+
other.primary_key + [k for k in self.primary_key
747+
if k not in other.primary_key])
742748
return result
743749

744750
def __mul__(self, other):

tests/schema.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,63 @@ class ComplexChild(dj.Lookup):
379379
definition = '\n'.join(['-> ComplexParent'] + ['child_id_{}: int'.format(i+1)
380380
for i in range(1)])
381381
contents = [tuple(i for i in range(9))]
382+
383+
384+
@schema
385+
class SubjectA(dj.Lookup):
386+
definition = """
387+
subject_id: varchar(32)
388+
---
389+
dob : date
390+
sex : enum('M', 'F', 'U')
391+
"""
392+
contents = [
393+
('mouse1', '2020-09-01', 'M'),
394+
('mouse2', '2020-03-19', 'F'),
395+
('mouse3', '2020-08-23', 'F')
396+
]
397+
398+
399+
@schema
400+
class SessionA(dj.Lookup):
401+
definition = """
402+
-> SubjectA
403+
session_start_time: datetime
404+
---
405+
session_dir='' : varchar(32)
406+
"""
407+
contents = [
408+
('mouse1', '2020-12-01 12:32:34', ''),
409+
('mouse1', '2020-12-02 12:32:34', ''),
410+
('mouse1', '2020-12-03 12:32:34', ''),
411+
('mouse1', '2020-12-04 12:32:34', '')
412+
]
413+
414+
415+
@schema
416+
class SessionStatusA(dj.Lookup):
417+
definition = """
418+
-> SessionA
419+
---
420+
status: enum('in_training', 'trained_1a', 'trained_1b', 'ready4ephys')
421+
"""
422+
contents = [
423+
('mouse1', '2020-12-01 12:32:34', 'in_training'),
424+
('mouse1', '2020-12-02 12:32:34', 'trained_1a'),
425+
('mouse1', '2020-12-03 12:32:34', 'trained_1b'),
426+
('mouse1', '2020-12-04 12:32:34', 'ready4ephys'),
427+
]
428+
429+
430+
@schema
431+
class SessionDateA(dj.Lookup):
432+
definition = """
433+
-> SubjectA
434+
session_date: date
435+
"""
436+
contents = [
437+
('mouse1', '2020-12-01'),
438+
('mouse1', '2020-12-02'),
439+
('mouse1', '2020-12-03'),
440+
('mouse1', '2020-12-04')
441+
]

tests/test_relational_operand.py

Lines changed: 19 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import datajoint as dj
1111
from .schema_simple import (A, B, D, E, F, L, DataA, DataB, TTestUpdate, IJ, JI,
1212
ReservedWord, OutfitLaunch)
13-
from .schema import Experiment, TTest3, Trial, Ephys, Child, Parent
13+
from .schema import (Experiment, TTest3, Trial, Ephys, Child, Parent, SubjectA, SessionA,
14+
SessionStatusA, SessionDateA)
1415

1516

1617
def setup():
@@ -160,76 +161,6 @@ def test_issue_376():
160161
def test_issue_463():
161162
assert_equal(((A & B) * B).fetch().size, len(A * B))
162163

163-
@staticmethod
164-
def test_issue_898():
165-
# https://github.com/datajoint/datajoint-python/issues/898
166-
schema = dj.schema('djtest_raphael')
167-
168-
@schema
169-
class Subject(dj.Lookup):
170-
definition = """
171-
subject_id: varchar(32)
172-
---
173-
dob : date
174-
sex : enum('M', 'F', 'U')
175-
"""
176-
contents = [
177-
('mouse1', '2020-09-01', 'M'),
178-
('mouse2', '2020-03-19', 'F'),
179-
('mouse3', '2020-08-23', 'F')
180-
]
181-
182-
@schema
183-
class Session(dj.Lookup):
184-
definition = """
185-
-> Subject
186-
session_start_time: datetime
187-
---
188-
session_dir='' : varchar(32)
189-
"""
190-
contents = [
191-
('mouse1', '2020-12-01 12:32:34', ''),
192-
('mouse1', '2020-12-02 12:32:34', ''),
193-
('mouse1', '2020-12-03 12:32:34', ''),
194-
('mouse1', '2020-12-04 12:32:34', '')
195-
]
196-
197-
@schema
198-
class SessionStatus(dj.Lookup):
199-
definition = """
200-
-> Session
201-
---
202-
status: enum('in_training', 'trained_1a', 'trained_1b', 'ready4ephys')
203-
"""
204-
contents = [
205-
('mouse1', '2020-12-01 12:32:34', 'in_training'),
206-
('mouse1', '2020-12-02 12:32:34', 'trained_1a'),
207-
('mouse1', '2020-12-03 12:32:34', 'trained_1b'),
208-
('mouse1', '2020-12-04 12:32:34', 'ready4ephys'),
209-
]
210-
211-
@schema
212-
class SessionDate(dj.Lookup):
213-
definition = """
214-
-> Subject
215-
session_date: date
216-
"""
217-
contents = [
218-
('mouse1', '2020-12-01'),
219-
('mouse1', '2020-12-02'),
220-
('mouse1', '2020-12-03'),
221-
('mouse1', '2020-12-04')
222-
]
223-
224-
subjects = Subject.aggr(
225-
SessionStatus & 'status="trained_1a" or status="trained_1b"',
226-
date_trained='min(date(session_start_time))')
227-
228-
print(f'subjects: {subjects}')
229-
print(f'SessionDate: {SessionDate()}')
230-
print(f'join: {SessionDate * subjects}')
231-
print(f'join query: {(SessionDate * subjects).make_sql()}')
232-
233164
@staticmethod
234165
def test_project():
235166
x = A().proj(a='id_a') # rename
@@ -557,3 +488,20 @@ def test_null_dict_restriction():
557488
assert len(q) == 1
558489
q = F & dict(id=5, date=None)
559490
assert len(q) == 1
491+
492+
@staticmethod
493+
def test_joins_with_aggregation():
494+
# https://github.com/datajoint/datajoint-python/issues/898
495+
# https://github.com/datajoint/datajoint-python/issues/899
496+
subjects = SubjectA.aggr(
497+
SessionStatusA & 'status="trained_1a" or status="trained_1b"',
498+
date_trained='min(date(session_start_time))')
499+
assert len(SessionDateA * subjects) == 4
500+
assert len(subjects * SessionDateA) == 4
501+
502+
subj_query = SubjectA.aggr(
503+
SessionA * SessionStatusA & 'status="trained_1a" or status="trained_1b"',
504+
date_trained='min(date(session_start_time))')
505+
session_dates = ((SessionDateA * (subj_query & 'date_trained<"2020-12-21"')) &
506+
'session_date<"date_trained"')
507+
assert len(session_dates) == 4

0 commit comments

Comments
 (0)