Skip to content

Commit 9597ba9

Browse files
refactor: Make window op node support non-unary ops (#1295)
1 parent 8b8155f commit 9597ba9

File tree

7 files changed

+65
-38
lines changed

7 files changed

+65
-38
lines changed

bigframes/core/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,8 +405,7 @@ def project_window_op(
405405
ArrayValue(
406406
nodes.WindowOpNode(
407407
child=self.node,
408-
column_name=ex.deref(column_name),
409-
op=op,
408+
expression=ex.UnaryAggregation(op, ex.deref(column_name)),
410409
window_spec=window_spec,
411410
output_name=ids.ColumnId(output_name),
412411
never_skip_nulls=never_skip_nulls,

bigframes/core/compile/aggregate_compiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,10 +479,9 @@ def _(
479479
return _apply_window_if_present(column.dense_rank(), window) + 1
480480

481481

482-
@compile_unary_agg.register
482+
@compile_nullary_agg.register
483483
def _(
484484
op: agg_ops.RowNumberOp,
485-
column: ibis_types.Column,
486485
window=None,
487486
) -> ibis_types.IntegerValue:
488487
return _apply_window_if_present(ibis_api.row_number(), window)

bigframes/core/compile/compiled.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,7 @@ def promote_offsets(self, col_id: str) -> OrderedIR:
861861
## Methods that only work with ordering
862862
def project_window_op(
863863
self,
864-
column_name: ex.DerefOp,
865-
op: agg_ops.UnaryWindowOp,
864+
expression: ex.Aggregation,
866865
window_spec: WindowSpec,
867866
output_name: str,
868867
*,
@@ -881,53 +880,66 @@ def project_window_op(
881880
# See: https://github.com/ibis-project/ibis/issues/9773
882881
used_exprs = map(
883882
self._compile_expression,
884-
itertools.chain(
885-
(column_name,), map(ex.DerefOp, window_spec.all_referenced_columns)
883+
map(
884+
ex.DerefOp,
885+
itertools.chain(
886+
expression.column_references, window_spec.all_referenced_columns
887+
),
886888
),
887889
)
888890
can_directly_window = not any(
889891
map(lambda x: is_literal(x) or is_window(x), used_exprs)
890892
)
891893
if not can_directly_window:
892894
return self._reproject_to_table().project_window_op(
893-
column_name,
894-
op,
895+
expression,
895896
window_spec,
896897
output_name,
897898
never_skip_nulls=never_skip_nulls,
898899
)
899900

900-
column = typing.cast(ibis_types.Column, self._compile_expression(column_name))
901901
window = self._ibis_window_from_spec(
902-
window_spec, require_total_order=op.uses_total_row_ordering
902+
window_spec, require_total_order=expression.op.uses_total_row_ordering
903903
)
904904
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
905905

906906
window_op = agg_compiler.compile_analytic(
907-
ex.UnaryAggregation(op, column_name),
907+
expression,
908908
window,
909909
bindings=bindings,
910910
)
911911

912+
inputs = tuple(
913+
typing.cast(ibis_types.Column, self._compile_expression(ex.DerefOp(column)))
914+
for column in expression.column_references
915+
)
912916
clauses = []
913-
if op.skips_nulls and not never_skip_nulls:
914-
clauses.append((column.isnull(), ibis_types.null()))
915-
if window_spec.min_periods:
916-
if op.skips_nulls:
917+
if expression.op.skips_nulls and not never_skip_nulls:
918+
for column in inputs:
919+
clauses.append((column.isnull(), ibis_types.null()))
920+
if window_spec.min_periods and len(inputs) > 0:
921+
if expression.op.skips_nulls:
917922
# Most operations do not count NULL values towards min_periods
923+
per_col_does_count = (column.notnull() for column in inputs)
924+
# All inputs must be non-null for observation to count
925+
is_observation = functools.reduce(
926+
lambda x, y: x & y, per_col_does_count
927+
).cast(int)
918928
observation_count = agg_compiler.compile_analytic(
919-
ex.UnaryAggregation(agg_ops.count_op, column_name),
929+
ex.UnaryAggregation(agg_ops.sum_op, ex.deref("_observation_count")),
920930
window,
921-
bindings=bindings,
931+
bindings={"_observation_count": is_observation},
922932
)
923933
else:
924934
# Operations like count treat even NULLs as valid observations for the sake of min_periods
925935
# notnull is just used to convert null values to non-null (FALSE) values to be counted
926-
denulled_value = typing.cast(ibis_types.BooleanColumn, column.notnull())
936+
is_observation = inputs[0].notnull()
927937
observation_count = agg_compiler.compile_analytic(
928-
ex.UnaryAggregation(agg_ops.count_op, ex.deref("_denulled")),
938+
ex.UnaryAggregation(
939+
agg_ops.count_op, ex.deref("_observation_count")
940+
),
929941
window,
930-
bindings={**bindings, "_denulled": denulled_value},
942+
bindings={"_observation_count": is_observation},
931943
)
932944
clauses.append(
933945
(

bigframes/core/compile/compiler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,7 @@ def compile_aggregate(self, node: nodes.AggregateNode, ordered: bool = True):
364364
@_compile_node.register
365365
def compile_window(self, node: nodes.WindowOpNode, ordered: bool = True):
366366
result = self.compile_ordered_ir(node.child).project_window_op(
367-
node.column_name,
368-
node.op,
367+
node.expression,
369368
node.window_spec,
370369
node.output_name.sql,
371370
never_skip_nulls=node.never_skip_nulls,

bigframes/core/compile/polars/compiler.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import dataclasses
1717
import functools
1818
import itertools
19-
from typing import cast, Sequence, TYPE_CHECKING
19+
from typing import cast, Sequence, Tuple, TYPE_CHECKING
2020

2121
import bigframes.core
2222
import bigframes.core.expression as ex
@@ -125,6 +125,24 @@ def get_args(
125125
f"Aggregation {agg} not yet supported in polars engine."
126126
)
127127

128+
def compile_agg_expr(self, expr: ex.Aggregation):
129+
if isinstance(expr, ex.NullaryAggregation):
130+
inputs: Tuple = ()
131+
elif isinstance(expr, ex.UnaryAggregation):
132+
assert isinstance(expr.arg, ex.DerefOp)
133+
inputs = (expr.arg.id.sql,)
134+
elif isinstance(expr, ex.BinaryAggregation):
135+
assert isinstance(expr.left, ex.DerefOp)
136+
assert isinstance(expr.right, ex.DerefOp)
137+
inputs = (
138+
expr.left.id.sql,
139+
expr.right.id.sql,
140+
)
141+
else:
142+
raise ValueError(f"Unexpected aggregation: {expr.op}")
143+
144+
return self.compile_agg_op(expr.op, inputs)
145+
128146
def compile_agg_op(self, op: agg_ops.WindowOp, inputs: Sequence[str] = []):
129147
if isinstance(op, agg_ops.ProductOp):
130148
# TODO: Need schema to cast back to original type if possible (e.g. float back to int)
@@ -320,9 +338,9 @@ def compile_sample(self, node: nodes.RandomSampleNode):
320338
@compile_node.register
321339
def compile_window(self, node: nodes.WindowOpNode):
322340
df = self.compile_node(node.child)
323-
agg_expr = self.agg_compiler.compile_agg_op(
324-
node.op, [node.column_name.id.sql]
325-
).alias(node.output_name.sql)
341+
agg_expr = self.agg_compiler.compile_agg_expr(node.expression).alias(
342+
node.output_name.sql
343+
)
326344
# Three window types: completely unbound, grouped and row bounded
327345

328346
window = node.window_spec

bigframes/core/nodes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import bigframes.core.slices as slices
3434
import bigframes.core.window_spec as window
3535
import bigframes.dtypes
36-
import bigframes.operations.aggregations as agg_ops
3736

3837
if typing.TYPE_CHECKING:
3938
import bigframes.core.ordering as orderings
@@ -1325,16 +1324,15 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
13251324

13261325
@dataclasses.dataclass(frozen=True, eq=False)
13271326
class WindowOpNode(UnaryNode):
1328-
column_name: ex.DerefOp
1329-
op: agg_ops.UnaryWindowOp
1327+
expression: ex.Aggregation
13301328
window_spec: window.WindowSpec
13311329
output_name: bigframes.core.identifiers.ColumnId
13321330
never_skip_nulls: bool = False
13331331
skip_reproject_unsafe: bool = False
13341332

13351333
def _validate(self):
13361334
"""Validate the local data in the node."""
1337-
assert self.column_name.id in self.child.ids
1335+
assert all(ref in self.child.ids for ref in self.expression.column_references)
13381336

13391337
@property
13401338
def non_local(self) -> bool:
@@ -1363,9 +1361,11 @@ def row_count(self) -> Optional[int]:
13631361

13641362
@functools.cached_property
13651363
def added_field(self) -> Field:
1366-
input_type = self.child.get_type(self.column_name.id)
1367-
new_item_dtype = self.op.output_type(input_type)
1368-
return Field(self.output_name, new_item_dtype)
1364+
input_types = self.child._dtype_lookup
1365+
return Field(
1366+
self.output_name,
1367+
bigframes.dtypes.dtype_for_etype(self.expression.output_type(input_types)),
1368+
)
13691369

13701370
@property
13711371
def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
@@ -1376,7 +1376,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
13761376
return self.child.prune(used_cols)
13771377
consumed_ids = (
13781378
used_cols.difference([self.output_name])
1379-
.union([self.column_name.id])
1379+
.union(self.expression.column_references)
13801380
.union(self.window_spec.all_referenced_columns)
13811381
)
13821382
return self.transform_children(lambda x: x.prune(consumed_ids))
@@ -1391,7 +1391,7 @@ def remap_vars(
13911391
def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
13921392
return dataclasses.replace(
13931393
self,
1394-
column_name=self.column_name.remap_column_refs(
1394+
expression=self.expression.remap_column_refs(
13951395
mappings, allow_partial_bindings=True
13961396
),
13971397
window_spec=self.window_spec.remap_column_refs(

bigframes/operations/aggregations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def skips_nulls(self):
381381

382382
# This should really be a NullaryWindowOp, but APIs don't support that yet.
383383
@dataclasses.dataclass(frozen=True)
384-
class RowNumberOp(UnaryWindowOp):
384+
class RowNumberOp(NullaryWindowOp):
385385
name: ClassVar[str] = "rownumber"
386386

387387
@property

0 commit comments

Comments
 (0)