@@ -67,11 +67,20 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):

         iobytes = io.BytesIO()
         pa_feather.write_feather(adapted_table, iobytes)
+        # Scan all columns by default; we define this list so it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+
         node = nodes.ReadLocalNode(
             iobytes.getvalue(),
             data_schema=schema,
             session=session,
             n_rows=arrow_table.num_rows,
+            scan_list=scan_list,
         )
         return cls(node)

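The ScanList introduced here separates which columns a read produces from how the underlying data is defined. A minimal, self-contained sketch of that idea, using invented stand-in dataclasses rather than the real bigframes nodes: pruning rewrites only the scan list, so the source definition (and anything derived from it, such as a cache key) is untouched.

# Hypothetical stand-ins for nodes.ScanItem / nodes.ScanList; not the bigframes API.
from dataclasses import dataclass
from typing import Set, Tuple

@dataclass(frozen=True)
class ScanItem:
    col_id: str       # stable internal identifier
    dtype: str        # simplified: dtype as a string
    source_name: str  # column name in the underlying data

@dataclass(frozen=True)
class ScanList:
    items: Tuple[ScanItem, ...]

    def prune(self, keep: Set[str]) -> "ScanList":
        # Dropping columns touches only the scan list; the source definition is
        # untouched, so two pruned reads of the same data still share a source.
        return ScanList(tuple(item for item in self.items if item.col_id in keep))

scan = ScanList((ScanItem("a", "Int64", "a"), ScanItem("b", "string", "b")))
print(scan.prune({"a"}))  # ScanList(items=(ScanItem(col_id='a', ...),))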
@@ -104,14 +113,30 @@ def from_table(
                 "Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.",
                 bigframes.exceptions.PreviewWarning,
             )
+        # Define the data source only for the needed columns; this makes row-hashing cheaper
+        table_def = nodes.GbqTable.from_table(table, columns=schema.names)
+
+        # Create the ordering from the available info
+        ordering = None
+        if offsets_col:
+            ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
+        elif primary_key:
+            ordering = orderings.TotalOrdering.from_primary_key(primary_key)
+
+        # Scan all columns by default; we define this list so it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+        source_def = nodes.BigqueryDataSource(
+            table=table_def, at_time=at_time, sql_predicate=predicate, ordering=ordering
+        )
         node = nodes.ReadTableNode(
-            table=nodes.GbqTable.from_table(table),
-            total_order_cols=(offsets_col,) if offsets_col else tuple(primary_key),
-            order_col_is_sequential=(offsets_col is not None),
-            columns=schema,
-            at_time=at_time,
+            source=source_def,
+            scan_list=scan_list,
             table_session=session,
-            sql_predicate=predicate,
         )
         return cls(node)

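The ordering branch above prefers an offsets column over a primary key. A rough sketch of that decision with an invented helper (not the bigframes API), under the assumption that an offsets column is a gapless 0..n-1 row number while a primary key is merely unique:

# Hypothetical helper illustrating the ordering choice; not the bigframes API.
from typing import Optional, Sequence, Tuple

def choose_ordering(
    offsets_col: Optional[str], primary_key: Sequence[str]
) -> Optional[Tuple[str, Tuple[str, ...]]]:
    if offsets_col:
        # Sequential integer order: supports cheap row numbering and range scans.
        return ("offsets", (offsets_col,))
    if primary_key:
        # Unique-key order: still a total ordering, but key values are arbitrary.
        return ("primary_key", tuple(primary_key))
    return None  # no total ordering available; one must be created downstream

print(choose_ordering("row_off", ()))  # ('offsets', ('row_off',))
print(choose_ordering(None, ["id"]))   # ('primary_key', ('id',))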
@@ -157,12 +182,22 @@ def as_cached(
         ordering: Optional[orderings.RowOrdering],
     ) -> ArrayValue:
         """
-        Replace the node with an equivalent one that references a tabel where the value has been materialized to.
+        Replace the node with an equivalent one that references a table where the value has been materialized to.
         """
+        table = nodes.GbqTable.from_table(cache_table)
+        source = nodes.BigqueryDataSource(table, ordering=ordering)
+        # Assumption: the GBQ cached table uses the field name as the BQ column name
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(field.id, field.dtype, field.id.name)
+                for field in self.node.fields
+            )
+        )
         node = nodes.CachedTableNode(
             original_node=self.node,
-            table=nodes.GbqTable.from_table(cache_table),
-            ordering=ordering,
+            source=source,
+            table_session=self.session,
+            scan_list=scan_list,
         )
         return ArrayValue(node)

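The scan list built here leans on the stated assumption that the cached table's BigQuery column names equal the internal field ids. A small illustrative sketch with invented types (not the bigframes API):

# Hypothetical Field type mirroring the identity mapping used for cached reads.
from dataclasses import dataclass
from typing import Iterable, Tuple

@dataclass(frozen=True)
class Field:
    id: str
    dtype: str

def scan_items_for_cache(fields: Iterable[Field]) -> Tuple[Tuple[str, str, str], ...]:
    # (internal id, dtype, source column name); the name reuses the id by assumption
    return tuple((f.id, f.dtype, f.id) for f in fields)

print(scan_items_for_cache((Field("col_0", "Int64"), Field("col_1", "string"))))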
@@ -379,28 +414,34 @@ def relational_join(
         conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
         type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
     ) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
+        l_mapping = {  # Identity mapping, only rename right side
+            lcol.name: lcol.name for lcol in self.node.ids
+        }
+        r_mapping = {  # Rename conflicting names
+            rcol.name: rcol.name
+            if (rcol.name not in l_mapping)
+            else bigframes.core.guid.generate_guid()
+            for rcol in other.node.ids
+        }
+        other_node = other.node
+        if set(other_node.ids) & set(self.node.ids):
+            other_node = nodes.SelectionNode(
+                other_node,
+                tuple(
+                    (ex.deref(old_id), ids.ColumnId(new_id))
+                    for old_id, new_id in r_mapping.items()
+                ),
+            )
+
         join_node = nodes.JoinNode(
             left_child=self.node,
-            right_child=other.node,
+            right_child=other_node,
             conditions=tuple(
-                (ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
+                (ex.deref(l_mapping[l_col]), ex.deref(r_mapping[r_col]))
+                for l_col, r_col in conditions
             ),
             type=type,
         )
-        # Maps input ids to output ids for caller convenience
-        l_size = len(self.node.schema)
-        l_mapping = {
-            lcol: ocol
-            for lcol, ocol in zip(
-                self.node.schema.names, join_node.schema.names[:l_size]
-            )
-        }
-        r_mapping = {
-            rcol: ocol
-            for rcol, ocol in zip(
-                other.node.schema.names, join_node.schema.names[l_size:]
-            )
-        }
         return ArrayValue(join_node), (l_mapping, r_mapping)

     def try_align_as_projection(
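The join now renames conflicting right-side columns before the JoinNode is built, instead of deriving the name mappings from the join output schema afterwards. A standalone sketch of that naming strategy with plain dicts and an invented guid helper (standing in for bigframes.core.guid.generate_guid): the left side keeps its names, the right side gets a fresh name only where it collides, and both mappings are returned so callers can translate pre-join column names into post-join ones.

# Plain-dict illustration of the rename-before-join mapping; not the bigframes API.
import itertools
from typing import Dict, Sequence, Tuple

_counter = itertools.count()

def generate_guid() -> str:
    # Stand-in for bigframes.core.guid.generate_guid
    return f"bfcol_{next(_counter)}"

def plan_join_names(
    left_cols: Sequence[str], right_cols: Sequence[str]
) -> Tuple[Dict[str, str], Dict[str, str]]:
    l_mapping = {c: c for c in left_cols}  # identity mapping on the left
    r_mapping = {
        c: (c if c not in l_mapping else generate_guid())  # rename only on collision
        for c in right_cols
    }
    return l_mapping, r_mapping

l_map, r_map = plan_join_names(["id", "x"], ["id", "y"])
print(l_map)  # {'id': 'id', 'x': 'x'}
print(r_map)  # {'id': 'bfcol_0', 'y': 'y'}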