@@ -67,11 +67,20 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):
 
         iobytes = io.BytesIO()
         pa_feather.write_feather(adapted_table, iobytes)
+        # Scan all columns by default; the list exists so it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+
         node = nodes.ReadLocalNode(
             iobytes.getvalue(),
             data_schema=schema,
             session=session,
             n_rows=arrow_table.num_rows,
+            scan_list=scan_list,
         )
         return cls(node)
 
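The scan list introduced here separates which columns are read from how the local data is defined: an optimizer can drop unused columns without touching the serialized feather payload. A minimal sketch of such a pruning step, assuming `ScanList` stores its items in an `items` tuple and each `ScanItem` carries its source column name as `source_id` (both names inferred from the constructor calls above, not confirmed by this diff):

```python
# Hypothetical pruning pass; `needed` is an illustrative column set.
# Only the scan list shrinks -- the node's source definition is unchanged.
needed = {"col_a", "col_b"}
pruned = nodes.ScanList(
    tuple(item for item in scan_list.items if item.source_id in needed)
)
```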
@@ -104,14 +113,30 @@ def from_table(
                 "Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.",
                 bigframes.exceptions.PreviewWarning,
             )
+        # Define the data source over only the needed columns; this makes row-hashing cheaper
+        table_def = nodes.GbqTable.from_table(table, columns=schema.names)
+
+        # Create a total ordering from the offsets column or primary key, if available
+        ordering = None
+        if offsets_col:
+            ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
+        elif primary_key:
+            ordering = orderings.TotalOrdering.from_primary_key(primary_key)
+
+        # Scan all columns by default; the list exists so it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+        source_def = nodes.BigqueryDataSource(
+            table=table_def, at_time=at_time, sql_predicate=predicate, ordering=ordering
+        )
         node = nodes.ReadTableNode(
-            table=nodes.GbqTable.from_table(table),
-            total_order_cols=(offsets_col,) if offsets_col else tuple(primary_key),
-            order_col_is_sequential=(offsets_col is not None),
-            columns=schema,
-            at_time=at_time,
+            source=source_def,
+            scan_list=scan_list,
             table_session=session,
-            sql_predicate=predicate,
         )
         return cls(node)
 
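The fallback above prefers an offsets column, a dense 0..n-1 sequence that is both unique and sequential, over a primary key, which is unique but says nothing about gaps. A self-contained sketch of the same decision, with toy tuples standing in for the real `TotalOrdering` objects:

```python
from typing import Optional, Sequence, Tuple

def choose_ordering(
    offsets_col: Optional[str], primary_key: Sequence[str]
) -> Optional[Tuple[str, Tuple[str, ...]]]:
    """Toy version of the ordering selection in from_table above."""
    if offsets_col:
        # Offsets yield a sequential total ordering.
        return ("offsets", (offsets_col,))
    if primary_key:
        # A primary key yields a total, but not sequential, ordering.
        return ("primary_key", tuple(primary_key))
    return None  # unordered source
```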
@@ -157,12 +182,22 @@ def as_cached(
         ordering: Optional[orderings.RowOrdering],
     ) -> ArrayValue:
         """
-        Replace the node with an equivalent one that references a tabel where the value has been materialized to.
+        Replace the node with an equivalent one that references a table where the value has been materialized to.
         """
+        table = nodes.GbqTable.from_table(cache_table)
+        source = nodes.BigqueryDataSource(table, ordering=ordering)
+        # Assumption: GBQ cached table uses field name as bq column name
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(field.id, field.dtype, field.id.name)
+                for field in self.node.fields
+            )
+        )
         node = nodes.CachedTableNode(
             original_node=self.node,
-            table=nodes.GbqTable.from_table(cache_table),
-            ordering=ordering,
+            source=source,
+            table_session=self.session,
+            scan_list=scan_list,
         )
         return ArrayValue(node)
 
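A cache substitution now carries three pieces: the original node (so later plans can still be matched against it), the materialized source with its ordering, and a scan list built on the assumption flagged in the comment that the cached table names its columns after the node's field ids. A hypothetical call site, assuming `value` is an existing ArrayValue and `tmp_table` is the `google.cloud.bigquery.Table` the engine materialized results into:

```python
# Illustrative only: re-point `value` at the materialized table, with no
# row ordering preserved by the cache write in this example.
cached_value = value.as_cached(tmp_table, ordering=None)
```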
@@ -379,28 +414,34 @@ def relational_join(
         conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
         type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
     ) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
+        l_mapping = {  # Identity mapping, only rename right side
+            lcol.name: lcol.name for lcol in self.node.ids
+        }
+        r_mapping = {  # Rename conflicting names
+            rcol.name: rcol.name
+            if (rcol.name not in l_mapping)
+            else bigframes.core.guid.generate_guid()
+            for rcol in other.node.ids
+        }
+        other_node = other.node
+        if set(other_node.ids) & set(self.node.ids):
+            other_node = nodes.SelectionNode(
+                other_node,
+                tuple(
+                    (ex.deref(old_id), ids.ColumnId(new_id))
+                    for old_id, new_id in r_mapping.items()
+                ),
+            )
+
         join_node = nodes.JoinNode(
             left_child=self.node,
-            right_child=other.node,
+            right_child=other_node,
             conditions=tuple(
-                (ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
+                (ex.deref(l_mapping[l_col]), ex.deref(r_mapping[r_col]))
+                for l_col, r_col in conditions
             ),
             type=type,
         )
-        # Maps input ids to output ids for caller convenience
-        l_size = len(self.node.schema)
-        l_mapping = {
-            lcol: ocol
-            for lcol, ocol in zip(
-                self.node.schema.names, join_node.schema.names[:l_size]
-            )
-        }
-        r_mapping = {
-            rcol: ocol
-            for rcol, ocol in zip(
-                other.node.schema.names, join_node.schema.names[l_size:]
-            )
-        }
         return ArrayValue(join_node), (l_mapping, r_mapping)
 
     def try_align_as_projection(
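The rewrite computes the caller-facing name mappings up front instead of recovering them from the join output's schema positions: left names pass through untouched, and only right-side names that collide are replaced with freshly generated ids before the join. A self-contained sketch of that rule, with `uuid` standing in for `bigframes.core.guid.generate_guid()`:

```python
import uuid

left_names = ["key", "value"]
right_names = ["key", "score"]  # "key" collides with the left side

l_mapping = {name: name for name in left_names}  # identity on the left
r_mapping = {
    name: name if name not in l_mapping else f"bfcol_{uuid.uuid4().hex[:8]}"
    for name in right_names
}
# e.g. {"key": "bfcol_3fa2c1d0", "score": "score"}
```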