Skip to content

Commit 71a8ab9

Browse files
refactor: Simplify projection nodes (#961)
1 parent c1cde19 commit 71a8ab9

File tree

11 files changed

+164
-91
lines changed

11 files changed

+164
-91
lines changed

bigframes/core/__init__.py

Lines changed: 23 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -192,49 +192,38 @@ def concat(self, other: typing.Sequence[ArrayValue]) -> ArrayValue:
192192
)
193193

194194
def project_to_id(self, expression: ex.Expression, output_id: str):
195-
if output_id in self.column_ids: # Mutate case
196-
exprs = [
197-
((expression if (col_id == output_id) else ex.free_var(col_id)), col_id)
198-
for col_id in self.column_ids
199-
]
200-
else: # append case
201-
self_projection = (
202-
(ex.free_var(col_id), col_id) for col_id in self.column_ids
203-
)
204-
exprs = [*self_projection, (expression, output_id)]
205195
return ArrayValue(
206196
nodes.ProjectionNode(
207197
child=self.node,
208-
assignments=tuple(exprs),
198+
assignments=(
199+
(
200+
expression,
201+
output_id,
202+
),
203+
),
209204
)
210205
)
211206

212207
def assign(self, source_id: str, destination_id: str) -> ArrayValue:
213208
if destination_id in self.column_ids: # Mutate case
214209
exprs = [
215210
(
216-
(
217-
ex.free_var(source_id)
218-
if (col_id == destination_id)
219-
else ex.free_var(col_id)
220-
),
211+
(source_id if (col_id == destination_id) else col_id),
221212
col_id,
222213
)
223214
for col_id in self.column_ids
224215
]
225216
else: # append case
226-
self_projection = (
227-
(ex.free_var(col_id), col_id) for col_id in self.column_ids
228-
)
229-
exprs = [*self_projection, (ex.free_var(source_id), destination_id)]
217+
self_projection = ((col_id, col_id) for col_id in self.column_ids)
218+
exprs = [*self_projection, (source_id, destination_id)]
230219
return ArrayValue(
231-
nodes.ProjectionNode(
220+
nodes.SelectionNode(
232221
child=self.node,
233-
assignments=tuple(exprs),
222+
input_output_pairs=tuple(exprs),
234223
)
235224
)
236225

237-
def assign_constant(
226+
def create_constant(
238227
self,
239228
destination_id: str,
240229
value: typing.Any,
@@ -244,49 +233,31 @@ def assign_constant(
244233
# Need to assign a data type when value is NaN.
245234
dtype = dtype or bigframes.dtypes.DEFAULT_DTYPE
246235

247-
if destination_id in self.column_ids: # Mutate case
248-
exprs = [
249-
(
250-
(
251-
ex.const(value, dtype)
252-
if (col_id == destination_id)
253-
else ex.free_var(col_id)
254-
),
255-
col_id,
256-
)
257-
for col_id in self.column_ids
258-
]
259-
else: # append case
260-
self_projection = (
261-
(ex.free_var(col_id), col_id) for col_id in self.column_ids
262-
)
263-
exprs = [*self_projection, (ex.const(value, dtype), destination_id)]
264236
return ArrayValue(
265237
nodes.ProjectionNode(
266238
child=self.node,
267-
assignments=tuple(exprs),
239+
assignments=((ex.const(value, dtype), destination_id),),
268240
)
269241
)
270242

271243
def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
272-
selections = ((ex.free_var(col_id), col_id) for col_id in column_ids)
244+
# This basically just drops and reorders columns - logically a no-op except as a final step
245+
selections = ((col_id, col_id) for col_id in column_ids)
273246
return ArrayValue(
274-
nodes.ProjectionNode(
247+
nodes.SelectionNode(
275248
child=self.node,
276-
assignments=tuple(selections),
249+
input_output_pairs=tuple(selections),
277250
)
278251
)
279252

280253
def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
281254
new_projection = (
282-
(ex.free_var(col_id), col_id)
283-
for col_id in self.column_ids
284-
if col_id not in columns
255+
(col_id, col_id) for col_id in self.column_ids if col_id not in columns
285256
)
286257
return ArrayValue(
287-
nodes.ProjectionNode(
258+
nodes.SelectionNode(
288259
child=self.node,
289-
assignments=tuple(new_projection),
260+
input_output_pairs=tuple(new_projection),
290261
)
291262
)
292263

@@ -422,15 +393,13 @@ def unpivot(
422393
col_expr = ops.case_when_op.as_expr(*cases)
423394
unpivot_exprs.append((col_expr, col_id))
424395

425-
label_exprs = ((ex.free_var(id), id) for id in index_col_ids)
426-
# passthrough columns are unchanged, just repeated N times each
427-
passthrough_exprs = ((ex.free_var(id), id) for id in passthrough_columns)
396+
unpivot_col_ids = [id for id, _ in unpivot_columns]
428397
return ArrayValue(
429398
nodes.ProjectionNode(
430399
child=joined_array.node,
431-
assignments=(*label_exprs, *unpivot_exprs, *passthrough_exprs),
400+
assignments=(*unpivot_exprs,),
432401
)
433-
)
402+
).select_columns([*index_col_ids, *unpivot_col_ids, *passthrough_columns])
434403

435404
def _cross_join_w_labels(
436405
self, labels_array: ArrayValue, join_side: typing.Literal["left", "right"]

bigframes/core/blocks.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ def multi_apply_unary_op(
939939
for col_id in columns:
940940
label = self.col_id_to_label[col_id]
941941
block, result_id = block.project_expr(
942-
expr.bind_all_variables({input_varname: ex.free_var(col_id)}),
942+
expr.bind_variables({input_varname: ex.free_var(col_id)}),
943943
label=label,
944944
)
945945
block = block.copy_values(result_id, col_id)
@@ -1006,7 +1006,7 @@ def create_constant(
10061006
dtype: typing.Optional[bigframes.dtypes.Dtype] = None,
10071007
) -> typing.Tuple[Block, str]:
10081008
result_id = guid.generate_guid()
1009-
expr = self.expr.assign_constant(result_id, scalar_constant, dtype=dtype)
1009+
expr = self.expr.create_constant(result_id, scalar_constant, dtype=dtype)
10101010
# Create index copy with label inserted
10111011
# See: https://pandas.pydata.org/docs/reference/api/pandas.Index.insert.html
10121012
labels = self.column_labels.insert(len(self.column_labels), label)
@@ -1067,7 +1067,7 @@ def aggregate_all_and_stack(
10671067
index_id = guid.generate_guid()
10681068
result_expr = self.expr.aggregate(
10691069
aggregations, dropna=dropna
1070-
).assign_constant(index_id, None, None)
1070+
).create_constant(index_id, None, None)
10711071
# Transpose as last operation so that final block has valid transpose cache
10721072
return Block(
10731073
result_expr,
@@ -1222,7 +1222,7 @@ def aggregate(
12221222
names: typing.List[Label] = []
12231223
if len(by_column_ids) == 0:
12241224
label_id = guid.generate_guid()
1225-
result_expr = result_expr.assign_constant(label_id, 0, pd.Int64Dtype())
1225+
result_expr = result_expr.create_constant(label_id, 0, pd.Int64Dtype())
12261226
index_columns = (label_id,)
12271227
names = [None]
12281228
else:
@@ -1614,17 +1614,22 @@ def add_prefix(self, prefix: str, axis: str | int | None = None) -> Block:
16141614
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
16151615
if axis_number == 0:
16161616
expr = self._expr
1617+
new_index_cols = []
16171618
for index_col in self._index_columns:
1619+
new_col = guid.generate_guid()
16181620
expr = expr.project_to_id(
16191621
expression=ops.add_op.as_expr(
16201622
ex.const(prefix),
16211623
ops.AsTypeOp(to_type="string").as_expr(index_col),
16221624
),
1623-
output_id=index_col,
1625+
output_id=new_col,
16241626
)
1627+
new_index_cols.append(new_col)
1628+
expr = expr.select_columns((*new_index_cols, *self.value_columns))
1629+
16251630
return Block(
16261631
expr,
1627-
index_columns=self.index_columns,
1632+
index_columns=new_index_cols,
16281633
column_labels=self.column_labels,
16291634
index_labels=self.index.names,
16301635
)
@@ -1635,17 +1640,21 @@ def add_suffix(self, suffix: str, axis: str | int | None = None) -> Block:
16351640
axis_number = utils.get_axis_number("rows" if (axis is None) else axis)
16361641
if axis_number == 0:
16371642
expr = self._expr
1643+
new_index_cols = []
16381644
for index_col in self._index_columns:
1645+
new_col = guid.generate_guid()
16391646
expr = expr.project_to_id(
16401647
expression=ops.add_op.as_expr(
16411648
ops.AsTypeOp(to_type="string").as_expr(index_col),
16421649
ex.const(suffix),
16431650
),
1644-
output_id=index_col,
1651+
output_id=new_col,
16451652
)
1653+
new_index_cols.append(new_col)
1654+
expr = expr.select_columns((*new_index_cols, *self.value_columns))
16461655
return Block(
16471656
expr,
1648-
index_columns=self.index_columns,
1657+
index_columns=new_index_cols,
16491658
column_labels=self.column_labels,
16501659
index_labels=self.index.names,
16511660
)

bigframes/core/compile/compiled.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,23 @@ def projection(
134134
) -> T:
135135
"""Apply an expression to the ArrayValue and assign the output to a column."""
136136
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
137-
values = [
137+
new_values = [
138138
op_compiler.compile_expression(expression, bindings).name(id)
139139
for expression, id in expression_id_pairs
140140
]
141+
result = self._select(tuple([*self._columns, *new_values])) # type: ignore
142+
return result
143+
144+
def selection(
145+
self: T,
146+
input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...],
147+
) -> T:
148+
"""Select existing columns by input id and rename each to its paired output id; only the listed columns are kept."""
149+
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
150+
values = [
151+
op_compiler.compile_expression(ex.free_var(input), bindings).name(id)
152+
for input, id in input_output_pairs
153+
]
141154
result = self._select(tuple(values)) # type: ignore
142155
return result
143156

bigframes/core/compile/compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,11 @@ def compile_reversed(self, node: nodes.ReversedNode, ordered: bool = True):
264264
else:
265265
return self.compile_unordered_ir(node.child)
266266

267+
@_compile_node.register
268+
def compile_selection(self, node: nodes.SelectionNode, ordered: bool = True):
269+
result = self.compile_node(node.child, ordered)
270+
return result.selection(node.input_output_pairs)
271+
267272
@_compile_node.register
268273
def compile_projection(self, node: nodes.ProjectionNode, ordered: bool = True):
269274
result = self.compile_node(node.child, ordered)

bigframes/core/expression.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ def output_type(
110110
...
111111

112112
@abc.abstractmethod
113-
def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
114-
"""Replace all variables with expression given in `bindings`."""
113+
def bind_variables(
114+
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
115+
) -> Expression:
116+
"""Replace variables with the expressions given in `bindings`.
117+
118+
If check_bind_all is True, validate that all free variables are bound to a new value.
119+
"""
115120
...
116121

117122
@property
@@ -141,7 +146,9 @@ def output_type(
141146
) -> dtypes.ExpressionType:
142147
return self.dtype
143148

144-
def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
149+
def bind_variables(
150+
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
151+
) -> Expression:
145152
return self
146153

147154
@property
@@ -178,11 +185,14 @@ def output_type(
178185
else:
179186
raise ValueError(f"Type of variable {self.id} has not been fixed.")
180187

181-
def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
188+
def bind_variables(
189+
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
190+
) -> Expression:
182191
if self.id in bindings.keys():
183192
return bindings[self.id]
184-
else:
193+
elif check_bind_all:
185194
raise ValueError(f"Variable {self.id} remains unbound")
195+
return self
186196

187197
@property
188198
def is_bijective(self) -> bool:
@@ -225,10 +235,15 @@ def output_type(
225235
)
226236
return self.op.output_type(*operand_types)
227237

228-
def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
238+
def bind_variables(
239+
self, bindings: Mapping[str, Expression], check_bind_all: bool = True
240+
) -> Expression:
229241
return OpExpression(
230242
self.op,
231-
tuple(input.bind_all_variables(bindings) for input in self.inputs),
243+
tuple(
244+
input.bind_variables(bindings, check_bind_all=check_bind_all)
245+
for input in self.inputs
246+
),
232247
)
233248

234249
@property

bigframes/core/nodes.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,15 +622,41 @@ def relation_ops_created(self) -> int:
622622
return 0
623623

624624

625+
@dataclass(frozen=True)
626+
class SelectionNode(UnaryNode):
627+
input_output_pairs: typing.Tuple[typing.Tuple[str, str], ...]
628+
629+
def __hash__(self):
630+
return self._node_hash
631+
632+
@functools.cached_property
633+
def schema(self) -> schemata.ArraySchema:
634+
input_types = self.child.schema._mapping
635+
items = tuple(
636+
schemata.SchemaItem(output, input_types[input])
637+
for input, output in self.input_output_pairs
638+
)
639+
return schemata.ArraySchema(items)
640+
641+
@property
642+
def variables_introduced(self) -> int:
643+
# This operation only renames variables, doesn't actually create new ones
644+
return 0
645+
646+
625647
@dataclass(frozen=True)
626648
class ProjectionNode(UnaryNode):
649+
"""Assigns new variables (without modifying existing ones)"""
650+
627651
assignments: typing.Tuple[typing.Tuple[ex.Expression, str], ...]
628652

629653
def __post_init__(self):
630654
input_types = self.child.schema._mapping
631655
for expression, id in self.assignments:
632656
# throws TypeError if invalid
633657
_ = expression.output_type(input_types)
658+
# Cannot assign to existing variables - append only!
659+
assert all(name not in self.child.schema.names for _, name in self.assignments)
634660

635661
def __hash__(self):
636662
return self._node_hash
@@ -644,7 +670,10 @@ def schema(self) -> schemata.ArraySchema:
644670
)
645671
for ex, id in self.assignments
646672
)
647-
return schemata.ArraySchema(items)
673+
schema = self.child.schema
674+
for item in items:
675+
schema = schema.append(item)
676+
return schema
648677

649678
@property
650679
def variables_introduced(self) -> int:

bigframes/core/ordering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def bind_variables(
6363
self, mapping: Mapping[str, expression.Expression]
6464
) -> OrderingExpression:
6565
return OrderingExpression(
66-
self.scalar_expression.bind_all_variables(mapping),
66+
self.scalar_expression.bind_variables(mapping),
6767
self.direction,
6868
self.na_last,
6969
)

0 commit comments

Comments
 (0)