@@ -131,13 +131,26 @@ def get_column_symbols(self):
131131
132132 # characterization
133133
def get_tables(self, *, replacements=None):
    """Get a dictionary of all tables used in an operator DAG.

    Every source is visited in order and the tables it reports are merged
    into one dictionary keyed by table key.

    :param replacements: optional dictionary mapping table keys to the
        table description objects that should stand in for same-keyed
        tables found below this node. A direct source that is a
        TableDescription with a matching key is swapped, in place in
        self.sources, for the replacement object.
    :return: dictionary mapping table key to table description object.
    :raises ValueError: if a replacement has a different column set than
        the table it would replace, or if one key is reported with two
        distinct representation objects.
    """
    tables = {}
    for i, s in enumerate(self.sources):
        if isinstance(s, TableDescription):
            if replacements is not None and s.key in replacements:
                orig_table = replacements[s.key]
                if s.column_set != orig_table.column_set:
                    raise ValueError("table " + s.key + " has two incompatible definitions")
                # Re-point this node at the shared definition object so the
                # identity check below (and in callers) can succeed.
                # Assigning by index does not resize the list, so it is safe
                # while enumerating it.
                self.sources[i] = orig_table
                s = orig_table
        for k, v in s.get_tables(replacements=replacements).items():
            if k not in tables:
                tables[k] = v
            elif tables[k] is not v:
                # Same key, different object: identity (not equality) is required.
                raise ValueError("Table " + k + " has two different representation objects")
    return tables
142155
143156 def columns_used_from_sources (self , using = None ):
@@ -533,22 +546,12 @@ def to_python_implementation(self, *, indent=0, strict=True, print_sources=True)
533546 s = s + ")"
534547 return s
535548
def get_tables(self, *, replacements=None):
    """Get a dictionary of all tables used in an operator DAG.

    For a table description this is a single-entry dictionary keyed by
    self.key.

    :param replacements: optional dictionary mapping table keys to table
        description objects; if it has an entry for self.key, that entry
        is reported in place of self.
    :return: {self.key: table}, where table is the replacement when one is
        supplied for self.key and self otherwise.
    :raises ValueError: if the replacement for self.key has a different
        column set than self.
    """
    if replacements is not None and self.key in replacements:
        replacement = replacements[self.key]
        # A stand-in is only valid when it describes the same columns;
        # otherwise two incompatible definitions would share one key.
        if self.column_set != replacement.column_set:
            raise ValueError(
                "Two tables with key " + self.key + " have different column sets."
            )
        return {self.key: replacement}
    return {self.key: self}
552555
553556 def eval_implementation (self , * , data_map , eval_env , data_model ):
554557 return data_model .table_step (op = self , data_map = data_map , eval_env = eval_env )
@@ -1115,7 +1118,7 @@ class NaturalJoinNode(ViewRepresentation):
11151118
11161119 def __init__ (self , a , b , * , by = None , jointype = "INNER" ):
11171120 a_tables = a .get_tables ()
1118- b_tables = b .get_tables ()
1121+ b_tables = b .get_tables (replacements = a_tables )
11191122 common_keys = set (a_tables .keys ()).intersection (b_tables .keys ())
11201123 for k in common_keys :
11211124 if a_tables [k ] is not b_tables [k ]:
@@ -1195,7 +1198,6 @@ def eval_implementation(self, *, data_map, eval_env, data_model):
11951198
11961199
11971200class ConvertRecordsNode (ViewRepresentation ):
1198- blocks_out_table : TableDescription
11991201
12001202 def __init__ (self , source , record_map , * , blocks_out_table = None ):
12011203 sources = [source ]
@@ -1206,24 +1208,28 @@ def __init__(self, source, record_map, *, blocks_out_table=None):
12061208 + [c for c in record_map .blocks_out .control_table .columns ],
12071209 )
12081210 if blocks_out_table is not None :
1209- sources = sources + [blocks_out_table ]
12101211 # check blocks_out_table is a direct table
12111212 if not isinstance (blocks_out_table , TableDescription ):
12121213 raise TypeError ("expected blocks_out_table to be a data_algebra.data_ops.TableDescription" )
1213- # check it is the exact same definition object if already present
1214+ # ensure table is the exact same definition object if already present
12141215 a_tables = source .get_tables ()
12151216 if blocks_out_table .key in a_tables .keys ():
12161217 a_table = a_tables [blocks_out_table .key ]
1218+ if not a_table .column_set == blocks_out_table .column_set :
1219+ raise ValueError ("blocks_out_table column definition does not match table already in op DAG" )
12171220 if not blocks_out_table is a_table :
1218- raise ValueError ("different definiton object for: " + blocks_out_table .key )
1221+ blocks_out_table = a_table
1222+ # check blocks_out_table is a direct table
1223+ if not isinstance (blocks_out_table , TableDescription ):
1224+ raise TypeError ("expected blocks_out_table to be a data_algebra.data_ops.TableDescription" )
12191225 # check it has at least the columns we expect
12201226 expect = [c for c in record_map .blocks_out .record_keys ] + \
12211227 [c for c in record_map .blocks_out .control_table .columns ]
12221228 unknown = set (expect ) - set (blocks_out_table .column_names )
12231229 if len (unknown ) > 0 :
12241230 raise ValueError ("blocks_out_table missing columns: " + str (unknown ))
1231+ sources = sources + [blocks_out_table ]
12251232 self .record_map = record_map
1226- self .blocks_out_table = blocks_out_table
12271233 unknown = set (self .record_map .columns_needed ) - set (source .column_names )
12281234 if len (unknown ) > 0 :
12291235 raise ValueError ("missing required columns: " + str (unknown ))
@@ -1244,8 +1250,11 @@ def collect_representation_implementation(self, *, pipeline=None, dialect="Pytho
12441250 od ["op" ] = "ConvertRecords"
12451251 od ["record_map" ] = self .record_map .to_simple_obj ()
12461252 od ['blocks_out_table' ] = None
1247- if self .blocks_out_table is not None :
1248- od ['blocks_out_table' ] = self .blocks_out_table .collect_representation (dialect = dialect )[0 ]
1253+ blocks_out_table = None
1254+ if len (self .sources ) > 1 :
1255+ blocks_out_table = self .sources [1 ]
1256+ if blocks_out_table is not None :
1257+ od ['blocks_out_table' ] = blocks_out_table .collect_representation (dialect = dialect )[0 ]
12491258 pipeline .insert (0 , od )
12501259 return self .sources [0 ].collect_representation_implementation (
12511260 pipeline = pipeline , dialect = dialect
@@ -1261,10 +1270,13 @@ def to_python_implementation(self, *, indent=0, strict=True, print_sources=True)
12611270 )
12621271 rm_str = self .record_map .__repr__ ()
12631272 rm_str = re .sub ("\n " , "\n " , rm_str )
1264- s = s + ("convert_record(" + rm_str +
1273+ s = s + "convert_record(" + rm_str
1274+ if len (self .sources ) > 1 :
1275+ s = s + (
12651276 "\n , blocks_out_table=" +
1266- self .blocks_out_table .to_python_implementation (indent = indent + 3 , strict = strict ) +
1267- ")" )
1277+ self .sources [1 ].to_python_implementation (indent = indent + 3 , strict = strict )
1278+ )
1279+ s = s + ")"
12681280 return s
12691281
12701282 def to_sql_implementation (self , db_model , * , using , temp_id_source ):
@@ -1279,7 +1291,7 @@ def to_sql_implementation(self, db_model, *, using, temp_id_source):
12791291 res = db_model .row_recs_to_blocks_query (
12801292 res ,
12811293 record_spec = self .record_map .blocks_out ,
1282- record_view = self .blocks_out_table ,
1294+ record_view = self .sources [ 1 ] ,
12831295 )
12841296 return res
12851297
0 commit comments