Skip to content

Commit b8cf4b4

Browse files
committed
Fix bug with unions.
1 parent 7764467 commit b8cf4b4

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

datajoint/expression.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,15 @@ def make_sql(self):
625625
if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes:
626626
# no secondary attributes: use UNION DISTINCT
627627
fields = arg1.primary_key
628-
return "({sql1}) UNION ({sql2})".format(
629-
sql1=arg1.make_sql(fields),
630-
sql2=arg2.make_sql(fields))
628+
if isinstance(arg1, Union):
629+
sql1 = arg1.make_sql()
630+
else:
631+
sql1 = arg1.make_sql(fields)
632+
if isinstance(arg2, Union):
633+
sql2 = arg2.make_sql()
634+
else:
635+
sql2 = arg2.make_sql(fields)
636+
return "({sql1}) UNION ({sql2})".format(sql1=sql1, sql2=sql2)
631637
# with secondary attributes, use union of left join with antijoin
632638
fields = self.heading.names
633639
sql1 = arg1.join(arg2, left=True).make_sql(fields)

tests/test_relational_operand.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from os import stat
12
import random
23
import string
34
import pandas
@@ -505,3 +506,12 @@ def test_joins_with_aggregation():
505506
session_dates = ((SessionDateA * (subj_query & 'date_trained<"2020-12-21"')) &
506507
'session_date<date_trained')
507508
assert len(session_dates) == 1
509+
510+
@staticmethod
511+
def test_union_multiple():
512+
# https://github.com/datajoint/datajoint-python/issues/926
513+
q1 = IJ & dict(j=2)
514+
q2 = (IJ & dict(j=2, i=0)) + (IJ & dict(j=2, i=1)) + (IJ & dict(j=2, i=2))
515+
x = set(zip(*q1.fetch('i', 'j')))
516+
y = set(zip(*q2.fetch('i', 'j')))
517+
assert_set_equal(x, y)

0 commit comments

Comments
 (0)