|
9 | 9 | from .preview import preview, repr_html |
10 | 10 | from .condition import AndList, Not, \ |
11 | 11 | make_condition, assert_join_compatibility, extract_column_names, PromiscuousOperand |
| 12 | +from .declare import CONSTANT_LITERALS |
12 | 13 |
|
13 | 14 | logger = logging.getLogger(__name__) |
14 | 15 |
|
@@ -44,6 +45,9 @@ class QueryExpression: |
44 | 45 | _heading = None |
45 | 46 | _support = None |
46 | 47 |
|
| 48 | + # Whether make_sql emits SELECT DISTINCT for this query |
| 49 | + _distinct = False |
| 50 | + |
47 | 51 | @property |
48 | 52 | def connection(self): |
49 | 53 | """ a dj.Connection object """ |
@@ -106,9 +110,8 @@ def make_sql(self, fields=None): |
106 | 110 | Make the SQL SELECT statement. |
107 | 111 | :param fields: used to explicitly set the select attributes |
108 | 112 | """ |
109 | | - distinct = self.heading.names == self.primary_key |
110 | 113 | return 'SELECT {distinct}{fields} FROM {from_}{where}'.format( |
111 | | - distinct="DISTINCT " if distinct else "", |
| 114 | + distinct="DISTINCT " if self._distinct else "", |
112 | 115 | fields=self.heading.as_sql(fields or self.heading.names), |
113 | 116 | from_=self.from_clause(), where=self.where_clause()) |
114 | 117 |
|
@@ -266,9 +269,11 @@ def join(self, other, semantic_check=True, left=False): |
266 | 269 | - join_attributes) |
267 | 270 | # need subquery if any of the join attributes are derived |
268 | 271 | need_subquery1 = (need_subquery1 or isinstance(self, Aggregation) or |
269 | | - any(n in self.heading.new_attributes for n in join_attributes)) |
| 272 | + any(n in self.heading.new_attributes for n in join_attributes) |
| 273 | + or isinstance(self, Union)) |
270 | 274 | need_subquery2 = (need_subquery2 or isinstance(other, Aggregation) or |
271 | | - any(n in other.heading.new_attributes for n in join_attributes)) |
| 275 | + any(n in other.heading.new_attributes for n in join_attributes) |
| 276 | + or isinstance(self, Union)) |
272 | 277 | if need_subquery1: |
273 | 278 | self = self.make_subquery() |
274 | 279 | if need_subquery2: |
@@ -309,9 +314,9 @@ def proj(self, *attributes, **named_attributes): |
309 | 314 | Each attribute name can only be used once. |
310 | 315 | """ |
311 | 316 | # new attributes in parentheses are included again with the new name without removing original |
312 | | - duplication_pattern = re.compile(r'\s*\(\s*(?P<name>[a-z][a-z_0-9]*)\s*\)\s*$') |
| 317 | + duplication_pattern = re.compile(fr'^\s*\(\s*(?!{"|".join(CONSTANT_LITERALS)})(?P<name>[a-zA-Z_]\w*)\s*\)\s*$') |
313 | 318 | # attributes without parentheses renamed |
314 | | - rename_pattern = re.compile(r'\s*(?P<name>[a-z][a-z_0-9]*)\s*$') |
| 319 | + rename_pattern = re.compile(fr'^\s*(?!{"|".join(CONSTANT_LITERALS)})(?P<name>[a-zA-Z_]\w*)\s*$') |
315 | 320 | replicate_map = {k: m.group('name') |
316 | 321 | for k, m in ((k, duplication_pattern.match(v)) for k, v in named_attributes.items()) if m} |
317 | 322 | rename_map = {k: m.group('name') |
@@ -440,8 +445,10 @@ def tail(self, limit=25, **fetch_kwargs): |
440 | 445 | def __len__(self): |
441 | 446 | """:return: number of elements in the result set e.g. ``len(q1)``.""" |
442 | 447 | return self.connection.query( |
443 | | - 'SELECT count(DISTINCT {fields}) FROM {from_}{where}'.format( |
444 | | - fields=self.heading.as_sql(self.primary_key, include_aliases=False), |
| 448 | + 'SELECT {select_} FROM {from_}{where}'.format( |
| 449 | + select_=('count(*)' if any(self._left) |
| 450 | + else 'count(DISTINCT {fields})'.format(fields=self.heading.as_sql( |
| 451 | + self.primary_key, include_aliases=False))), |
445 | 452 | from_=self.from_clause(), |
446 | 453 | where=self.where_clause())).fetchone()[0] |
447 | 454 |
|
@@ -554,7 +561,7 @@ def create(cls, arg, group, keep_all_rows=False): |
554 | 561 | if inspect.isclass(group) and issubclass(group, QueryExpression): |
555 | 562 | group = group() # instantiate if a class |
556 | 563 | assert isinstance(group, QueryExpression) |
557 | | - if keep_all_rows and len(group.support) > 1: |
| 564 | + if keep_all_rows and len(group.support) > 1 or group.heading.new_attributes: |
558 | 565 | group = group.make_subquery() # subquery if left joining a join or if group has new attributes |
559 | 566 | join = arg.join(group, left=keep_all_rows) # reuse the join logic |
560 | 567 | result = cls() |
@@ -718,6 +725,7 @@ def __and__(self, other): |
718 | 725 | if not isinstance(other, QueryExpression): |
719 | 726 | raise DataJointError('Set U can only be restricted with a QueryExpression.') |
720 | 727 | result = copy.copy(other) |
| 728 | + result._distinct = True |
721 | 729 | result._heading = result.heading.set_primary_key(self.primary_key) |
722 | 730 | result = result.proj() |
723 | 731 | return result |
|
0 commit comments