@@ -37,7 +37,7 @@ class QueryExpression:
3737 _restriction = None
3838 _restriction_attributes = None
3939 _left = [] # True for left joins, False for inner joins
40- _join_attributes = []
40+ _original_heading = None # heading before projections
4141
4242 # subclasses or instantiators must provide values
4343 _connection = None
@@ -61,6 +61,11 @@ def heading(self):
6161 """ a dj.Heading object, reflects the effects of the projection operator .proj """
6262 return self ._heading
6363
64+ @property
65+ def original_heading (self ):
66+ """ a dj.Heading object reflecting the attributes before projection """
67+ return self ._original_heading or self .heading
68+
6469 @property
6570 def restriction (self ):
6671 """ a AndList object of restrictions applied to input to produce the result """
@@ -85,11 +90,10 @@ def from_clause(self):
8590 support = ('(' + src .make_sql () + ') as `_s%x`' % next (
8691 self ._subquery_alias_count ) if isinstance (src , QueryExpression ) else src for src in self .support )
8792 clause = next (support )
88- for s , a , left in zip (support , self . _join_attributes , self ._left ):
89- clause += '{left} JOIN {clause}{using }' .format (
93+ for s , left in zip (support , self ._left ):
94+ clause += 'NATURAL {left} JOIN {clause}' .format (
9095 left = " LEFT" if left else "" ,
91- clause = s ,
92- using = "" if not a else " USING (%s)" % "," .join ('`%s`' % _ for _ in a ))
96+ clause = s )
9397 return clause
9498
9599 def where_clause (self ):
@@ -241,34 +245,29 @@ def join(self, other, semantic_check=True, left=False):
241245 other = other () # instantiate
242246 if not isinstance (other , QueryExpression ):
243247 raise DataJointError ("The argument of join must be a QueryExpression" )
244- other_clash = set (other .heading .names ) | set (
245- (other .heading [n ].attribute_expression .strip ('`' ) for n in other .heading .new_attributes ))
246- self_clash = set (self .heading .names ) | set (
247- (self .heading [n ].attribute_expression for n in self .heading .new_attributes ))
248- need_subquery1 = isinstance (self , Union ) or any (
249- n for n in self .heading .new_attributes if (
250- n in other_clash or self .heading [n ].attribute_expression .strip ('`' ) in other_clash ))
251- need_subquery2 = (len (other .support ) > 1 or
252- isinstance (self , Union ) or any (
253- n for n in other .heading .new_attributes if (
254- n in self_clash or other .heading [n ].attribute_expression .strip ('`' ) in other_clash )))
248+ if semantic_check :
249+ assert_join_compatibility (self , other )
250+ join_attributes = set (n for n in self .heading .names if n in other .heading .names )
251+ # needs subquery if FROM class has common attributes with the other's FROM clause
252+ need_subquery1 = need_subquery2 = bool (
253+ (set (self .original_heading .names ) & set (other .original_heading .names ))
254+ - join_attributes )
255+ # need subquery if any of the join attributes are derived
256+ need_subquery1 = need_subquery1 or any (n in self .heading .new_attributes for n in join_attributes )
257+ need_subquery2 = need_subquery2 or any (n in other .heading .new_attributes for n in join_attributes )
255258 if need_subquery1 :
256259 self = self .make_subquery ()
257260 if need_subquery2 :
258261 other = other .make_subquery ()
259- if semantic_check :
260- assert_join_compatibility (self , other )
261262 result = QueryExpression ()
262263 result ._connection = self .connection
263264 result ._support = self .support + other .support
264- result ._join_attributes = (
265- self ._join_attributes + [[a for a in self .heading .names if a in other .heading .names ]] +
266- other ._join_attributes )
267265 result ._left = self ._left + [left ] + other ._left
268266 result ._heading = self .heading .join (other .heading )
269267 result ._restriction = AndList (self .restriction )
270268 result ._restriction .append (other .restriction )
271- assert len (result .support ) == len (result ._join_attributes ) + 1 == len (result ._left ) + 1
269+ result ._original_heading = self .original_heading .join (other .original_heading )
270+ assert len (result .support ) == len (result ._left ) + 1
272271 return result
273272
274273 def __add__ (self , other ):
@@ -371,6 +370,7 @@ def proj(self, *attributes, **named_attributes):
371370 need_subquery = any (name in self .restriction_attributes for name in self .heading .new_attributes )
372371
373372 result = self .make_subquery () if need_subquery else copy .copy (self )
373+ result ._original_heading = result .original_heading
374374 result ._heading = result .heading .select (
375375 attributes , rename_map = dict (** rename_map , ** replicate_map ), compute_map = compute_map )
376376 return result
@@ -525,7 +525,6 @@ def create(cls, arg, group, keep_all_rows=False):
525525 result ._connection = join .connection
526526 result ._heading = join .heading .set_primary_key (arg .primary_key ) # use left operand's primary key
527527 result ._support = join .support
528- result ._join_attributes = join ._join_attributes
529528 result ._left = join ._left
530529 result ._left_restrict = join .restriction # WHERE clause applied before GROUP BY
531530 result ._grouping_attributes = result .primary_key
0 commit comments