@@ -44,6 +44,9 @@ class QueryExpression:
4444 _heading = None
4545 _support = None
4646
47+ # Whether make_sql emits SELECT DISTINCT for this query (False: plain SELECT)
48+ _distinct = False
49+
4750 @property
4851 def connection (self ):
4952 """ a dj.Connection object """
@@ -106,9 +109,8 @@ def make_sql(self, fields=None):
106109 Make the SQL SELECT statement.
107110 :param fields: used to explicitly set the select attributes
108111 """
109- distinct = self .heading .names == self .primary_key
110112 return 'SELECT {distinct}{fields} FROM {from_}{where}' .format (
111- distinct = "DISTINCT " if distinct else "" ,
113+ distinct = "DISTINCT " if self . _distinct else "" ,
112114 fields = self .heading .as_sql (fields or self .heading .names ),
113115 from_ = self .from_clause (), where = self .where_clause ())
114116
@@ -266,9 +268,11 @@ def join(self, other, semantic_check=True, left=False):
266268 - join_attributes )
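267269 # need subquery if an operand is an Aggregation or Union, or if any of the join attributes are derived; NOTE(review): need_subquery2 below tests isinstance(self, Union) rather than other -- confirm this is intended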
268270 need_subquery1 = (need_subquery1 or isinstance (self , Aggregation ) or
269- any (n in self .heading .new_attributes for n in join_attributes ))
271+ any (n in self .heading .new_attributes for n in join_attributes )
272+ or isinstance (self , Union ))
270273 need_subquery2 = (need_subquery2 or isinstance (other , Aggregation ) or
271- any (n in other .heading .new_attributes for n in join_attributes ))
274+ any (n in other .heading .new_attributes for n in join_attributes )
275+ or isinstance (self , Union ))
272276 if need_subquery1 :
273277 self = self .make_subquery ()
274278 if need_subquery2 :
@@ -440,8 +444,10 @@ def tail(self, limit=25, **fetch_kwargs):
440444 def __len__ (self ):
441445 """:return: number of elements in the result set e.g. ``len(q1)``."""
442446 return self .connection .query (
443- 'SELECT count(DISTINCT {fields}) FROM {from_}{where}' .format (
444- fields = self .heading .as_sql (self .primary_key , include_aliases = False ),
447+ 'SELECT {select_} FROM {from_}{where}' .format (
448+ select_ = ('count(*)' if any (self ._left )
449+ else 'count(DISTINCT {fields})' .format (fields = self .heading .as_sql (
450+ self .primary_key , include_aliases = False ))),
445451 from_ = self .from_clause (),
446452 where = self .where_clause ())).fetchone ()[0 ]
447453
@@ -554,7 +560,7 @@ def create(cls, arg, group, keep_all_rows=False):
554560 if inspect .isclass (group ) and issubclass (group , QueryExpression ):
555561 group = group () # instantiate if a class
556562 assert isinstance (group , QueryExpression )
557- if keep_all_rows and len (group .support ) > 1 :
563+ if keep_all_rows and len (group .support ) > 1 or group . heading . new_attributes :
558564 group = group .make_subquery () # subquery if left joining a join
559565 join = arg .join (group , left = keep_all_rows ) # reuse the join logic
560566 result = cls ()
@@ -718,6 +724,7 @@ def __and__(self, other):
718724 if not isinstance (other , QueryExpression ):
719725 raise DataJointError ('Set U can only be restricted with a QueryExpression.' )
720726 result = copy .copy (other )
727+ result ._distinct = True
721728 result ._heading = result .heading .set_primary_key (self .primary_key )
722729 result = result .proj ()
723730 return result
0 commit comments