Commit 4105dba

refactor: Make expression nodes prunable (#1030)
1 parent 057f3f0 commit 4105dba

File tree: 9 files changed, +358 -257 lines

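This commit splits each leaf node's definition into a stable data source plus a ScanList of scanned columns, so later rewrites can drop unused columns (prune the scan) without changing the source definition used for things like row-hashing. Below is a minimal sketch, not part of the commit, of what pruning a ScanList could look like; it assumes only what the hunks below show: ScanList wraps a tuple of ScanItem entries, each unpacking to (ColumnId, dtype, source label), and ColumnId exposes a .name.

# Hypothetical helper, for illustration only: ScanList/ScanItem shapes are taken
# from the hunks below; prune_scan itself is not part of bigframes or this commit.
from bigframes.core import nodes

def prune_scan(scan_list: nodes.ScanList, keep: set[str]) -> nodes.ScanList:
    # Keep only the scan items whose exposed ColumnId is still referenced;
    # the underlying source definition is left untouched.
    return nodes.ScanList(
        tuple(item for item in scan_list.items if item[0].name in keep)
    )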
bigframes/core/__init__.py

Lines changed: 66 additions & 25 deletions
@@ -67,11 +67,20 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):
 
         iobytes = io.BytesIO()
         pa_feather.write_feather(adapted_table, iobytes)
+        # Scan all columns by default, we define this list as it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+
         node = nodes.ReadLocalNode(
             iobytes.getvalue(),
             data_schema=schema,
             session=session,
             n_rows=arrow_table.num_rows,
+            scan_list=scan_list,
         )
         return cls(node)
 
@@ -104,14 +113,30 @@ def from_table(
                 "Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.",
                 bigframes.exceptions.PreviewWarning,
             )
+        # define data source only for needed columns, this makes row-hashing cheaper
+        table_def = nodes.GbqTable.from_table(table, columns=schema.names)
+
+        # create ordering from info
+        ordering = None
+        if offsets_col:
+            ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
+        elif primary_key:
+            ordering = orderings.TotalOrdering.from_primary_key(primary_key)
+
+        # Scan all columns by default, we define this list as it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+        source_def = nodes.BigqueryDataSource(
+            table=table_def, at_time=at_time, sql_predicate=predicate, ordering=ordering
+        )
         node = nodes.ReadTableNode(
-            table=nodes.GbqTable.from_table(table),
-            total_order_cols=(offsets_col,) if offsets_col else tuple(primary_key),
-            order_col_is_sequential=(offsets_col is not None),
-            columns=schema,
-            at_time=at_time,
+            source=source_def,
+            scan_list=scan_list,
             table_session=session,
-            sql_predicate=predicate,
         )
         return cls(node)
 
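The from_table path now builds three pieces separately: a GbqTable definition restricted to the needed columns (the in-diff comment notes this makes row-hashing cheaper), an optional total ordering, and a default scan list over all columns. A small sketch of just the ordering fallback, using the two constructors that appear in the hunk; the module alias and column names are assumptions for illustration:

# Sketch only: mirrors the if/elif above with invented inputs.
import bigframes.core.ordering as orderings

offsets_col = None                     # no precomputed offsets column
primary_key = ("user_id", "event_ts")  # invented primary-key columns

ordering = None
if offsets_col:
    ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
elif primary_key:
    ordering = orderings.TotalOrdering.from_primary_key(primary_key)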
@@ -157,12 +182,22 @@ def as_cached(
         ordering: Optional[orderings.RowOrdering],
     ) -> ArrayValue:
         """
-        Replace the node with an equivalent one that references a tabel where the value has been materialized to.
+        Replace the node with an equivalent one that references a table where the value has been materialized to.
         """
+        table = nodes.GbqTable.from_table(cache_table)
+        source = nodes.BigqueryDataSource(table, ordering=ordering)
+        # Assumption: GBQ cached table uses field name as bq column name
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(field.id, field.dtype, field.id.name)
+                for field in self.node.fields
+            )
+        )
         node = nodes.CachedTableNode(
             original_node=self.node,
-            table=nodes.GbqTable.from_table(cache_table),
-            ordering=ordering,
+            source=source,
+            table_session=self.session,
+            scan_list=scan_list,
         )
         return ArrayValue(node)
 
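as_cached now wraps the materialized table in a BigqueryDataSource and rebuilds the scan list from the node's existing fields, relying on the stated assumption that the cached table's BigQuery column names match the field ids' names. A plain-Python illustration of that mapping, with invented field values:

# Each existing field keeps its id and dtype and is scanned from the BigQuery
# column of the same name; the tuples below are invented for illustration.
fields = [("bfcol_0", "Int64"), ("bfcol_1", "string")]        # (id.name, dtype)
scan_items = [(name, dtype, name) for name, dtype in fields]  # (id, dtype, bq column)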
@@ -379,28 +414,34 @@ def relational_join(
         conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
         type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
     ) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
+        l_mapping = {  # Identity mapping, only rename right side
+            lcol.name: lcol.name for lcol in self.node.ids
+        }
+        r_mapping = {  # Rename conflicting names
+            rcol.name: rcol.name
+            if (rcol.name not in l_mapping)
+            else bigframes.core.guid.generate_guid()
+            for rcol in other.node.ids
+        }
+        other_node = other.node
+        if set(other_node.ids) & set(self.node.ids):
+            other_node = nodes.SelectionNode(
+                other_node,
+                tuple(
+                    (ex.deref(old_id), ids.ColumnId(new_id))
+                    for old_id, new_id in r_mapping.items()
+                ),
+            )
+
         join_node = nodes.JoinNode(
             left_child=self.node,
-            right_child=other.node,
+            right_child=other_node,
             conditions=tuple(
-                (ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
+                (ex.deref(l_mapping[l_col]), ex.deref(r_mapping[r_col]))
+                for l_col, r_col in conditions
             ),
             type=type,
         )
-        # Maps input ids to output ids for caller convenience
-        l_size = len(self.node.schema)
-        l_mapping = {
-            lcol: ocol
-            for lcol, ocol in zip(
-                self.node.schema.names, join_node.schema.names[:l_size]
-            )
-        }
-        r_mapping = {
-            rcol: ocol
-            for rcol, ocol in zip(
-                other.node.schema.names, join_node.schema.names[l_size:]
-            )
-        }
         return ArrayValue(join_node), (l_mapping, r_mapping)
 
     def try_align_as_projection(
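relational_join now decides the output column names up front: left columns keep their names, and right columns that collide receive a generated guid, applied through a SelectionNode before the join. The same mappings are returned to the caller instead of being reconstructed from the join schema afterwards. A plain-Python illustration of the mapping behaviour, with invented column names and guid:

# Both inputs expose an "id" column; only the right side's copy is renamed.
left_ids = ["id", "x"]
right_ids = ["id", "y"]

l_mapping = {c: c for c in left_ids}                 # {"id": "id", "x": "x"}
r_mapping = {
    c: c if c not in l_mapping else "bfcol_guid_42"  # invented guid stand-in
    for c in right_ids
}                                                    # {"id": "bfcol_guid_42", "y": "y"}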

bigframes/core/compile/compiled.py

Lines changed: 7 additions & 7 deletions
@@ -45,7 +45,6 @@
     RowOrdering,
     TotalOrdering,
 )
-import bigframes.core.schema as schemata
 import bigframes.core.sql
 from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
 import bigframes.dtypes
@@ -585,9 +584,7 @@ def has_total_order(self) -> bool:
 
     @classmethod
     def from_pandas(
-        cls,
-        pd_df: pandas.DataFrame,
-        schema: schemata.ArraySchema,
+        cls, pd_df: pandas.DataFrame, scan_cols: bigframes.core.nodes.ScanList
     ) -> OrderedIR:
         """
         Builds an in-memory only (SQL only) expr from a pandas dataframe.
@@ -603,18 +600,21 @@ def from_pandas(
         # derive the ibis schema from the original pandas schema
         ibis_schema = [
             (
-                name,
+                local_label,
                 bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(dtype),
             )
-            for name, dtype in zip(schema.names, schema.dtypes)
+            for id, dtype, local_label in scan_cols.items
         ]
         ibis_schema.append((ORDER_ID_COLUMN, ibis_dtypes.int64))
 
         keys_memtable = ibis.memtable(ibis_values, schema=ibis.schema(ibis_schema))
 
         return cls(
             keys_memtable,
-            columns=[keys_memtable[column].name(column) for column in pd_df.columns],
+            columns=[
+                keys_memtable[local_label].name(col_id.sql)
+                for col_id, _, local_label in scan_cols.items
+            ],
             ordering=TotalOrdering.from_offset_col(ORDER_ID_COLUMN),
             hidden_ordering_columns=(keys_memtable[ORDER_ID_COLUMN],),
         )
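from_pandas now takes the scan list instead of an ArraySchema: the ibis memtable is keyed by the local pandas label, and each column is renamed to the ColumnId's SQL name. A minimal sketch of that rename step, using only the ibis calls visible in the hunk and invented names:

# Sketch only: build a memtable keyed by the local label, then rename the column
# to a stable identifier, mirroring the columns=[...] expression above.
import ibis

memtable = ibis.memtable({"my_column": [1, 2, 3]})  # local pandas label
renamed = memtable["my_column"].name("bfcol_0")     # stable sql id (invented)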
