Skip to content

Commit cd2c729

Browse files
refactor: combine all projection nodes into single node type (#317)
1 parent 4eb64f6 commit cd2c729

File tree

9 files changed

+187
-184
lines changed

9 files changed

+187
-184
lines changed

bigframes/core/__init__.py

Lines changed: 91 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
import pandas
2323

2424
import bigframes.core.compile as compiling
25-
import bigframes.core.expression as expressions
25+
import bigframes.core.expression as ex
2626
import bigframes.core.guid
2727
import bigframes.core.nodes as nodes
2828
from bigframes.core.ordering import OrderingColumnReference
@@ -114,12 +114,6 @@ def row_count(self) -> ArrayValue:
114114
return ArrayValue(nodes.RowCountNode(child=self.node))
115115

116116
# Operations
117-
118-
def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
119-
return ArrayValue(
120-
nodes.DropColumnsNode(child=self.node, columns=tuple(columns))
121-
)
122-
123117
def filter(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
124118
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
125119
return ArrayValue(
@@ -140,21 +134,104 @@ def promote_offsets(self, col_id: str) -> ArrayValue:
140134
"""
141135
return ArrayValue(nodes.PromoteOffsetsNode(child=self.node, col_id=col_id))
142136

143-
def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
144-
return ArrayValue(
145-
nodes.SelectNode(child=self.node, column_ids=tuple(column_ids))
146-
)
147-
148137
def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
    """Append together multiple ArrayValue objects.

    The resulting ``ConcatNode`` lists this value's node first, followed by
    the node of each entry in ``other`` in the order given.
    """
    child_nodes = [self.node]
    child_nodes.extend(value.node for value in other)
    return ArrayValue(nodes.ConcatNode(children=tuple(child_nodes)))
153142

154-
def project(self, expression: expressions.Expression, output_id: str):
143+
def project_to_id(self, expression: ex.Expression, output_id: str) -> ArrayValue:
    """Project ``expression`` into the column identified by ``output_id``.

    Every existing column is carried through as a free-variable reference to
    itself. If ``output_id`` is already one of ``self.column_ids``, the
    expression takes that column's place (its position is kept); otherwise
    the expression is appended as a new trailing column.

    Args:
        expression: Expression whose result becomes column ``output_id``.
        output_id: Id of the column to overwrite or create.

    Returns:
        A new ArrayValue wrapping one ``ProjectionNode`` over this node.
    """
    column_ids = self.column_ids  # read once; used for both the scan and the test
    assignments = [
        (expression if col_id == output_id else ex.free_var(col_id), col_id)
        for col_id in column_ids
    ]
    if output_id not in column_ids:
        # Append case: nothing was replaced above, so add the new column last.
        assignments.append((expression, output_id))
    return ArrayValue(
        nodes.ProjectionNode(child=self.node, assignments=tuple(assignments))
    )
160+
161+
def assign(self, source_id: str, destination_id: str) -> ArrayValue:
    """Set column ``destination_id`` to a reference to column ``source_id``.

    If ``destination_id`` already exists it is overwritten in place;
    otherwise it is appended after the existing columns. All other columns
    pass through unchanged.

    Args:
        source_id: Id of the column to read from.
        destination_id: Id of the column to overwrite or create.

    Returns:
        A new ArrayValue wrapping one ``ProjectionNode`` over this node.
    """
    # Same mutate-or-append projection as project_to_id, with a plain column
    # reference as the expression, so reuse it instead of duplicating it.
    return self.project_to_id(ex.free_var(source_id), destination_id)
185+
186+
def assign_constant(
    self,
    destination_id: str,
    value: typing.Any,
    dtype: typing.Optional[bigframes.dtypes.Dtype],
) -> ArrayValue:
    """Set column ``destination_id`` to the constant ``value``.

    If ``destination_id`` already exists it is overwritten in place;
    otherwise it is appended after the existing columns. All other columns
    pass through unchanged.

    Args:
        destination_id: Id of the column to overwrite or create.
        value: The constant, passed to ``ex.const``.
        dtype: Optional dtype passed to ``ex.const`` alongside ``value``.

    Returns:
        A new ArrayValue wrapping one ``ProjectionNode`` over this node.
    """
    # Same mutate-or-append projection as project_to_id, with a constant
    # expression, so reuse it instead of duplicating it.
    return self.project_to_id(ex.const(value, dtype), destination_id)
215+
216+
def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
    """Return an ArrayValue exposing only ``column_ids``, in the order given.

    Each requested id becomes a free-variable reference to the column of the
    same id, emitted through a single ``ProjectionNode``.
    """
    assignments = tuple([(ex.free_var(col_id), col_id) for col_id in column_ids])
    projection_node = nodes.ProjectionNode(child=self.node, assignments=assignments)
    return ArrayValue(projection_node)
224+
225+
def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
    """Return an ArrayValue without the columns named in ``columns``.

    The remaining columns keep their relative order and are carried through
    as free-variable references in a single ``ProjectionNode``.

    Args:
        columns: Ids of the columns to remove. Any iterable is accepted,
            including one-shot iterators.

    Returns:
        A new ArrayValue wrapping one ``ProjectionNode`` over this node.
    """
    # Materialize once. Testing ``in`` directly against a one-shot iterator
    # would exhaust it on the first column, so later columns would never be
    # dropped; a set also makes each membership test O(1).
    dropped = frozenset(columns)
    kept = tuple(
        (ex.free_var(col_id), col_id)
        for col_id in self.column_ids
        if col_id not in dropped
    )
    return ArrayValue(nodes.ProjectionNode(child=self.node, assignments=kept))
160237

@@ -277,25 +354,6 @@ def unpivot(
277354
)
278355
)
279356

280-
def assign(self, source_id: str, destination_id: str) -> ArrayValue:
281-
return ArrayValue(
282-
nodes.AssignNode(
283-
child=self.node, source_id=source_id, destination_id=destination_id
284-
)
285-
)
286-
287-
def assign_constant(
288-
self,
289-
destination_id: str,
290-
value: typing.Any,
291-
dtype: typing.Optional[bigframes.dtypes.Dtype],
292-
) -> ArrayValue:
293-
return ArrayValue(
294-
nodes.AssignConstantNode(
295-
child=self.node, destination_id=destination_id, value=value, dtype=dtype
296-
)
297-
)
298-
299357
def join(
300358
self,
301359
self_column_ids: typing.Sequence[str],

bigframes/core/blocks.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -671,7 +671,7 @@ def project_expr(
671671
"""
672672
# TODO(tbergeron): handle labels safely so callers don't need to
673673
result_id = guid.generate_guid()
674-
array_val = self._expr.project(expr, result_id)
674+
array_val = self._expr.project_to_id(expr, result_id)
675675
block = Block(
676676
array_val,
677677
index_columns=self.index_columns,
@@ -1226,11 +1226,11 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
12261226
if axis_number == 0:
12271227
expr = self._expr
12281228
for index_col in self._index_columns:
1229-
add_prefix = ops.add_op.as_expr(
1230-
ex.const(prefix), ops.AsTypeOp(to_type="string").as_expr(index_col)
1231-
)
1232-
expr = expr.project(
1233-
expression=add_prefix,
1229+
expr = expr.project_to_id(
1230+
expression=ops.add_op.as_expr(
1231+
ex.const(prefix),
1232+
ops.AsTypeOp(to_type="string").as_expr(index_col),
1233+
),
12341234
output_id=index_col,
12351235
)
12361236
return Block(
@@ -1249,11 +1249,11 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
12491249
if axis_number == 0:
12501250
expr = self._expr
12511251
for index_col in self._index_columns:
1252-
add_suffix = ops.add_op.as_expr(
1253-
ops.AsTypeOp(to_type="string").as_expr(index_col), ex.const(suffix)
1254-
)
1255-
expr = expr.project(
1256-
expression=add_suffix,
1252+
expr = expr.project_to_id(
1253+
expression=ops.add_op.as_expr(
1254+
ops.AsTypeOp(to_type="string").as_expr(index_col),
1255+
ex.const(suffix),
1256+
),
12571257
output_id=index_col,
12581258
)
12591259
return Block(
@@ -1557,7 +1557,7 @@ def merge(
15571557
coalesced_ids = []
15581558
for left_id, right_id in zip(left_join_ids, right_join_ids):
15591559
coalesced_id = guid.generate_guid()
1560-
joined_expr = joined_expr.project(
1560+
joined_expr = joined_expr.project_to_id(
15611561
ops.coalesce_op.as_expr(
15621562
get_column_left[left_id], get_column_right[right_id]
15631563
),

bigframes/core/compile/compiled.py

Lines changed: 44 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,8 @@
2626
import ibis.expr.types as ibis_types
2727
import pandas
2828

29-
import bigframes.constants as constants
3029
import bigframes.core.compile.scalar_op_compiler as op_compilers
31-
import bigframes.core.expression as expressions
30+
import bigframes.core.expression as ex
3231
import bigframes.core.guid
3332
from bigframes.core.ordering import (
3433
encode_order_string,
@@ -96,16 +95,6 @@ def _reduced_predicate(self) -> typing.Optional[ibis_types.BooleanValue]:
9695
else None
9796
)
9897

99-
@abc.abstractmethod
100-
def select_columns(self: T, column_ids: typing.Sequence[str]) -> T:
101-
"""Creates a new expression based on this expression with new columns."""
102-
...
103-
104-
def drop_columns(self: T, columns: Iterable[str]) -> T:
105-
return self.select_columns(
106-
[col for col in self.column_ids if col not in columns]
107-
)
108-
10998
@abc.abstractmethod
11099
def filter(self: T, predicate_id: str, keep_null: bool = False) -> T:
111100
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
@@ -152,40 +141,26 @@ def _reproject_to_table(self: T) -> T:
152141
"""
153142
...
154143

155-
def project_expression(
144+
def projection(
    self: T,
    expression_id_pairs: typing.Tuple[typing.Tuple[ex.Expression, str], ...],
) -> T:
    """Compile every (expression, output id) pair and select the results.

    Each expression is compiled with all of this IR's columns available as
    bindings, named with its output id, and the compiled values are handed
    to ``_select`` as the new set of columns.

    Args:
        expression_id_pairs: Pairs of (expression, id of the column that
            should hold its result).

    Returns:
        The IR produced by ``_select``, reprojected to a table when any of
        the expressions is a constant.
    """
    bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
    values = tuple(
        op_compiler.compile_expression(expression, bindings).name(output_id)
        for expression, output_id in expression_id_pairs
    )
    result = self._select(values)  # type: ignore

    # Need to reproject to convert ibis Scalar to ibis Column object
    if any(expression.is_const for expression, _ in expression_id_pairs):
        result = result._reproject_to_table()
    return result
174160

175-
def assign_constant(
176-
self: T,
177-
destination_id: str,
178-
value: typing.Any,
179-
dtype: typing.Optional[bigframes.dtypes.Dtype],
180-
) -> T:
181-
# TODO(b/281587571): Solve scalar constant aggregation problem w/Ibis.
182-
ibis_value = bigframes.dtypes.literal_to_ibis_scalar(value, dtype)
183-
if ibis_value is None:
184-
raise NotImplementedError(
185-
f"Type not supported as scalar value {type(value)}. {constants.FEEDBACK_LINK}"
186-
)
187-
expr = self._set_or_replace_by_id(destination_id, ibis_value)
188-
return expr._reproject_to_table()
161+
@abc.abstractmethod
def _select(self: T, values: typing.Tuple[ibis_types.Value]) -> T:
    """Return a new IR whose columns are ``values``.

    Abstract hook; ``projection`` calls it with the compiled, named values.
    """
    ...
189164

190165
@abc.abstractmethod
191166
def _set_or_replace_by_id(self: T, id: str, new_value: ibis_types.Value) -> T:
@@ -330,14 +305,6 @@ def _to_ibis_expr(
330305
table = table.filter(ibis.random() < ibis.literal(fraction))
331306
return table
332307

333-
def select_columns(self, column_ids: typing.Sequence[str]) -> UnorderedIR:
334-
"""Creates a new expression based on this expression with new columns."""
335-
columns = [self._get_ibis_column(col_id) for col_id in column_ids]
336-
builder = self.builder()
337-
builder.columns = list(columns)
338-
new_expr = builder.build()
339-
return new_expr
340-
341308
def filter(self, predicate_id: str, keep_null: bool = False) -> UnorderedIR:
342309
condition = typing.cast(
343310
ibis_types.BooleanValue, self._get_ibis_column(predicate_id)
@@ -577,6 +544,11 @@ def _set_or_replace_by_id(
577544
builder.columns = [*self.columns, new_value.name(id)]
578545
return builder.build()
579546

547+
def _select(self, values: typing.Tuple[ibis_types.Value, ...]) -> UnorderedIR:
    """Return a new IR whose columns are exactly ``values``.

    Args:
        values: The ibis values that become the columns of the result.

    Returns:
        The IR built from this IR's builder with its columns replaced.
    """
    builder = self.builder()
    # Store a list, like every other assignment to ``builder.columns``.
    builder.columns = list(values)
    return builder.build()
551+
580552
def _reproject_to_table(self) -> UnorderedIR:
581553
"""
582554
Internal operators that projects the internal representation into a
@@ -816,20 +788,6 @@ def promote_offsets(self, col_id: str) -> OrderedIR:
816788
]
817789
return expr_builder.build()
818790

819-
def select_columns(self, column_ids: typing.Sequence[str]) -> OrderedIR:
820-
"""Creates a new expression based on this expression with new columns."""
821-
columns = [self._get_ibis_column(col_id) for col_id in column_ids]
822-
expr = self
823-
for ordering_column in set(self.column_ids).intersection(
824-
[col_ref.column_id for col_ref in self._ordering.ordering_value_columns]
825-
):
826-
# Need to hide ordering columns that are being dropped. Alternatively, could project offsets
827-
expr = expr._hide_column(ordering_column)
828-
builder = expr.builder()
829-
builder.columns = list(columns)
830-
new_expr = builder.build()
831-
return new_expr
832-
833791
## Methods that only work with ordering
834792
def project_window_op(
835793
self,
@@ -1221,6 +1179,29 @@ def _set_or_replace_by_id(self, id: str, new_value: ibis_types.Value) -> Ordered
12211179
builder.columns = [*self.columns, new_value.name(id)]
12221180
return builder.build()
12231181

1182+
def _select(self, values: typing.Tuple[ibis_types.Value, ...]) -> OrderedIR:
    """Replace the columns with ``values``, hiding ordering columns first.

    An ordering column is hidden via ``_hide_column`` before the new columns
    are installed when either:

    * it is currently a visible column but no value in ``values`` carries
      its name (it is being dropped), or
    * a value in ``values`` carries its name but is not ``equals`` to the
      current column (it is being overwritten).

    Args:
        values: The ibis values that become the columns of the result.
            Assumes each value has been given a name — TODO confirm for
            callers other than ``projection``.

    Returns:
        The IR built with ``values`` as its columns.
    """
    ordering_col_ids = [
        col_ref.column_id for col_ref in self._ordering.ordering_value_columns
    ]
    # ``Value.name`` is the rename *method* (this module calls it as
    # ``.name(...)``), so it cannot serve as a dict key for lookups by id;
    # ``get_name()`` returns the actual column name.
    selected = {value.get_name(): value for value in values}
    ir = self
    for ordering_id in ordering_col_ids:
        if ordering_id not in selected:
            # Drop case: the id is being dropped, hide it first.
            if ordering_id in ir.column_ids:
                ir = ir._hide_column(ordering_id)
        elif not selected[ordering_id].equals(ir._get_any_column(ordering_id)):
            # Mutate case: the id is being redefined, hide the old column.
            ir = ir._hide_column(ordering_id)

    builder = ir.builder()
    builder.columns = list(values)
    return builder.build()
1204+
12241205
## Ordering specific helpers
12251206
def _get_any_column(self, key: str) -> ibis_types.Value:
12261207
"""Gets the Ibis expression for a given column. Will also get hidden columns."""

0 commit comments

Comments
 (0)