@@ -37,7 +37,7 @@ class QueryExpression:
37
37
_restriction = None
38
38
_restriction_attributes = None
39
39
_left = [] # True for left joins, False for inner joins
40
- _join_attributes = []
40
+ _original_heading = None # heading before projections
41
41
42
42
# subclasses or instantiators must provide values
43
43
_connection = None
@@ -61,6 +61,11 @@ def heading(self):
61
61
""" a dj.Heading object, reflects the effects of the projection operator .proj """
62
62
return self ._heading
63
63
64
+ @property
65
+ def original_heading (self ):
66
+ """ a dj.Heading object reflecting the attributes before projection """
67
+ return self ._original_heading or self .heading
68
+
64
69
@property
65
70
def restriction (self ):
66
71
""" a AndList object of restrictions applied to input to produce the result """
@@ -85,11 +90,10 @@ def from_clause(self):
85
90
support = ('(' + src .make_sql () + ') as `_s%x`' % next (
86
91
self ._subquery_alias_count ) if isinstance (src , QueryExpression ) else src for src in self .support )
87
92
clause = next (support )
88
- for s , a , left in zip (support , self . _join_attributes , self ._left ):
89
- clause += '{left} JOIN {clause}{using }' .format (
93
+ for s , left in zip (support , self ._left ):
94
+ clause += 'NATURAL {left} JOIN {clause}' .format (
90
95
left = " LEFT" if left else "" ,
91
- clause = s ,
92
- using = "" if not a else " USING (%s)" % "," .join ('`%s`' % _ for _ in a ))
96
+ clause = s )
93
97
return clause
94
98
95
99
def where_clause (self ):
@@ -241,34 +245,29 @@ def join(self, other, semantic_check=True, left=False):
241
245
other = other () # instantiate
242
246
if not isinstance (other , QueryExpression ):
243
247
raise DataJointError ("The argument of join must be a QueryExpression" )
244
- other_clash = set (other .heading .names ) | set (
245
- (other .heading [n ].attribute_expression .strip ('`' ) for n in other .heading .new_attributes ))
246
- self_clash = set (self .heading .names ) | set (
247
- (self .heading [n ].attribute_expression for n in self .heading .new_attributes ))
248
- need_subquery1 = isinstance (self , Union ) or any (
249
- n for n in self .heading .new_attributes if (
250
- n in other_clash or self .heading [n ].attribute_expression .strip ('`' ) in other_clash ))
251
- need_subquery2 = (len (other .support ) > 1 or
252
- isinstance (self , Union ) or any (
253
- n for n in other .heading .new_attributes if (
254
- n in self_clash or other .heading [n ].attribute_expression .strip ('`' ) in other_clash )))
248
+ if semantic_check :
249
+ assert_join_compatibility (self , other )
250
+ join_attributes = set (n for n in self .heading .names if n in other .heading .names )
251
+ # needs subquery if FROM class has common attributes with the other's FROM clause
252
+ need_subquery1 = need_subquery2 = bool (
253
+ (set (self .original_heading .names ) & set (other .original_heading .names ))
254
+ - join_attributes )
255
+ # need subquery if any of the join attributes are derived
256
+ need_subquery1 = need_subquery1 or any (n in self .heading .new_attributes for n in join_attributes )
257
+ need_subquery2 = need_subquery2 or any (n in other .heading .new_attributes for n in join_attributes )
255
258
if need_subquery1 :
256
259
self = self .make_subquery ()
257
260
if need_subquery2 :
258
261
other = other .make_subquery ()
259
- if semantic_check :
260
- assert_join_compatibility (self , other )
261
262
result = QueryExpression ()
262
263
result ._connection = self .connection
263
264
result ._support = self .support + other .support
264
- result ._join_attributes = (
265
- self ._join_attributes + [[a for a in self .heading .names if a in other .heading .names ]] +
266
- other ._join_attributes )
267
265
result ._left = self ._left + [left ] + other ._left
268
266
result ._heading = self .heading .join (other .heading )
269
267
result ._restriction = AndList (self .restriction )
270
268
result ._restriction .append (other .restriction )
271
- assert len (result .support ) == len (result ._join_attributes ) + 1 == len (result ._left ) + 1
269
+ result ._original_heading = self .original_heading .join (other .original_heading )
270
+ assert len (result .support ) == len (result ._left ) + 1
272
271
return result
273
272
274
273
def __add__ (self , other ):
@@ -371,6 +370,7 @@ def proj(self, *attributes, **named_attributes):
371
370
need_subquery = any (name in self .restriction_attributes for name in self .heading .new_attributes )
372
371
373
372
result = self .make_subquery () if need_subquery else copy .copy (self )
373
+ result ._original_heading = result .original_heading
374
374
result ._heading = result .heading .select (
375
375
attributes , rename_map = dict (** rename_map , ** replicate_map ), compute_map = compute_map )
376
376
return result
@@ -525,7 +525,6 @@ def create(cls, arg, group, keep_all_rows=False):
525
525
result ._connection = join .connection
526
526
result ._heading = join .heading .set_primary_key (arg .primary_key ) # use left operand's primary key
527
527
result ._support = join .support
528
- result ._join_attributes = join ._join_attributes
529
528
result ._left = join ._left
530
529
result ._left_restrict = join .restriction # WHERE clause applied before GROUP BY
531
530
result ._grouping_attributes = result .primary_key
0 commit comments