
Commit 62d4314

Merge branch 'main' into session_simplify
2 parents f8da6f8 + 7072627 commit 62d4314


52 files changed: +1370 -204 lines (six of the changed files are shown below)

bigframes/core/bigframe_node.py

Lines changed: 20 additions & 32 deletions
@@ -20,17 +20,7 @@
 import functools
 import itertools
 import typing
-from typing import (
-    Callable,
-    Dict,
-    Generator,
-    Iterable,
-    Mapping,
-    Sequence,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union

 from bigframes.core import expression, field, identifiers
 import bigframes.core.schema as schemata
@@ -309,33 +299,31 @@ def unique_nodes(
             seen.add(item)
             stack.extend(item.child_nodes)

-    def edges(
+    def iter_nodes_topo(
         self: BigFrameNode,
-    ) -> Generator[Tuple[BigFrameNode, BigFrameNode], None, None]:
-        for item in self.unique_nodes():
-            for child in item.child_nodes:
-                yield (item, child)
-
-    def iter_nodes_topo(self: BigFrameNode) -> Generator[BigFrameNode, None, None]:
-        """Returns nodes from bottom up."""
-        queue = collections.deque(
-            [node for node in self.unique_nodes() if not node.child_nodes]
-        )
-
+    ) -> Generator[BigFrameNode, None, None]:
+        """Returns nodes in reverse topological order, using Kahn's algorithm."""
         child_to_parents: Dict[
-            BigFrameNode, Set[BigFrameNode]
-        ] = collections.defaultdict(set)
-        for parent, child in self.edges():
-            child_to_parents[child].add(parent)
-
-        yielded = set()
+            BigFrameNode, list[BigFrameNode]
+        ] = collections.defaultdict(list)
+        out_degree: Dict[BigFrameNode, int] = collections.defaultdict(int)
+
+        queue: collections.deque["BigFrameNode"] = collections.deque()
+        for node in list(self.unique_nodes()):
+            num_children = len(node.child_nodes)
+            out_degree[node] = num_children
+            if num_children == 0:
+                queue.append(node)
+            for child in node.child_nodes:
+                child_to_parents[child].append(node)

         while queue:
             item = queue.popleft()
             yield item
-            yielded.add(item)
-            for parent in child_to_parents[item]:
-                if set(parent.child_nodes).issubset(yielded):
+            parents = child_to_parents.get(item, [])
+            for parent in parents:
+                out_degree[parent] -= 1
+                if out_degree[parent] == 0:
                     queue.append(parent)

     def top_down(
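
Reviewer note: below is a minimal, self-contained sketch of the Kahn-style traversal the new iter_nodes_topo implements, run on a toy DAG (the dict-based graph and names are illustrative, not the bigframes node API). The out-degree counter makes the traversal O(V + E), replacing the old issubset check that re-scanned each parent's children on every yield.

import collections

# Toy DAG: each node maps to its children; "leaf" has none.
children = {"root": ["a", "b"], "a": ["leaf"], "b": ["leaf"], "leaf": []}

def iter_nodes_topo(graph):
    child_to_parents = collections.defaultdict(list)
    out_degree = {}
    queue = collections.deque()
    for node, kids in graph.items():
        out_degree[node] = len(kids)
        if not kids:
            queue.append(node)  # childless nodes seed the queue
        for child in kids:
            child_to_parents[child].append(node)
    while queue:
        item = queue.popleft()
        yield item  # a node is emitted only after all of its children
        for parent in child_to_parents[item]:
            out_degree[parent] -= 1
            if out_degree[parent] == 0:
                queue.append(parent)

print(list(iter_nodes_topo(children)))  # ['leaf', 'a', 'b', 'root']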

bigframes/core/blocks.py

Lines changed: 4 additions & 40 deletions
@@ -1266,46 +1266,10 @@ def aggregate_all_and_stack(
                 index_labels=[None],
             ).transpose(original_row_index=pd.Index([None]), single_row_mode=True)
         else:  # axis_n == 1
-            # using offsets as identity to group on.
-            # TODO: Allow to promote identity/total_order columns instead for better perf
-            expr_with_offsets, offset_col = self.expr.promote_offsets()
-            stacked_expr, (_, value_col_ids, passthrough_cols,) = unpivot(
-                expr_with_offsets,
-                row_labels=self.column_labels,
-                unpivot_columns=[tuple(self.value_columns)],
-                passthrough_columns=[*self.index_columns, offset_col],
-            )
-            # these corresponed to passthrough_columns provided to unpivot
-            index_cols = passthrough_cols[:-1]
-            og_offset_col = passthrough_cols[-1]
-            index_aggregations = [
-                (
-                    ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id)),
-                    col_id,
-                )
-                for col_id in index_cols
-            ]
-            # TODO: may need add NullaryAggregation in main_aggregation
-            # when agg add support for axis=1, needed for agg("size", axis=1)
-            assert isinstance(
-                operation, agg_ops.UnaryAggregateOp
-            ), f"Expected a unary operation, but got {operation}. Please report this error and how you got here to the BigQuery DataFrames team (bit.ly/bigframes-feedback)."
-            main_aggregation = (
-                ex.UnaryAggregation(operation, ex.deref(value_col_ids[0])),
-                value_col_ids[0],
-            )
-            # Drop row identity after aggregating over it
-            result_expr = stacked_expr.aggregate(
-                [*index_aggregations, main_aggregation],
-                by_column_ids=[og_offset_col],
-                dropna=dropna,
-            ).drop_columns([og_offset_col])
-            return Block(
-                result_expr,
-                index_columns=index_cols,
-                column_labels=[None],
-                index_labels=self.index.names,
-            )
+            as_array = ops.ToArrayOp().as_expr(*(col for col in self.value_columns))
+            reduced = ops.ArrayReduceOp(operation).as_expr(as_array)
+            block, id = self.project_expr(reduced, None)
+            return block.select_column(id)

     def aggregate_size(
         self,
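
The rewritten else branch drops the unpivot-and-group strategy in favor of a per-row array reduction. A hedged pure-pandas analogue of the new semantics (illustrative only, not the bigframes implementation):

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
as_array = df.apply(list, axis=1)  # ToArrayOp analogue: pack each row into a list
reduced = as_array.map(sum)        # ArrayReduceOp(sum) analogue: reduce each list
print(reduced.tolist())            # [4, 6], matching df.sum(axis=1)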

bigframes/core/compile/ibis_compiler/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def _(
 ) -> ibis_types.NumericValue:
     # Will be null if all inputs are null. Pandas defaults to zero sum though.
     bq_sum = _apply_window_if_present(column.sum(), window)
-    return bq_sum.fill_null(ibis_types.literal(0))
+    return bq_sum.coalesce(ibis_types.literal(0))


 @compile_unary_agg.register
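
The surrounding comment explains the wrapping: SQL SUM over all-NULL input yields NULL, while pandas' sum of an all-NaN column is 0. A minimal sketch, assuming a recent ibis where Value.coalesce and Expr.as_table are available:

import ibis

t = ibis.memtable({"x": [None, None]}, schema={"x": "float64"})
wrapped = t.x.sum().coalesce(0)  # COALESCE(SUM(x), 0): a NULL sum becomes 0
print(ibis.to_sql(wrapped.as_table(), dialect="bigquery"))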

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 22 additions & 0 deletions
@@ -1201,6 +1201,28 @@ def array_slice_op_impl(x: ibis_types.Value, op: ops.ArraySliceOp):
     return res


+@scalar_op_compiler.register_nary_op(ops.ToArrayOp, pass_op=False)
+def to_arry_op_impl(*values: ibis_types.Value):
+    do_upcast_bool = any(t.type().is_numeric() for t in values)
+    if do_upcast_bool:
+        values = tuple(
+            val.cast(ibis_dtypes.int64) if val.type().is_boolean() else val
+            for val in values
+        )
+    return ibis_api.array(values)
+
+
+@scalar_op_compiler.register_unary_op(ops.ArrayReduceOp, pass_op=True)
+def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp):
+    import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compilers
+
+    return typing.cast(ibis_types.ArrayValue, x).reduce(
+        lambda arr_vals: agg_compilers.compile_unary_agg(
+            op.aggregation, typing.cast(ibis_types.Column, arr_vals)
+        )
+    )
+
+
 # JSON Ops
 @scalar_op_compiler.register_binary_op(ops.JSONSet, pass_op=True)
 def json_set_op_impl(x: ibis_types.Value, y: ibis_types.Value, op: ops.JSONSet):
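
A hedged sketch of the bool-upcast branch in to_arry_op_impl above: array elements must share one type, so when any element is numeric, boolean elements are cast to INT64 first (assumes a recent ibis; the memtable is illustrative):

import ibis

t = ibis.memtable({"flag": [True, False], "x": [1, 2]})
# BOOL and INT64 cannot share one array, so the bool column is upcast first.
arr = ibis.array([t.flag.cast("int64"), t.x])
print(ibis.to_sql(t.select(arr.name("arr")), dialect="bigquery"))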

bigframes/core/compile/polars/compiler.py

Lines changed: 112 additions & 0 deletions
@@ -31,9 +31,12 @@
 import bigframes.dtypes
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
+import bigframes.operations.array_ops as arr_ops
 import bigframes.operations.bool_ops as bool_ops
 import bigframes.operations.comparison_ops as comp_ops
+import bigframes.operations.date_ops as date_ops
 import bigframes.operations.datetime_ops as dt_ops
+import bigframes.operations.frequency_ops as freq_ops
 import bigframes.operations.generic_ops as gen_ops
 import bigframes.operations.json_ops as json_ops
 import bigframes.operations.numeric_ops as num_ops
@@ -74,6 +77,20 @@ def decorator(func):


 if polars_installed:
+    _FREQ_MAPPING = {
+        "Y": "1y",
+        "Q": "1q",
+        "M": "1mo",
+        "W": "1w",
+        "D": "1d",
+        "h": "1h",
+        "min": "1m",
+        "s": "1s",
+        "ms": "1ms",
+        "us": "1us",
+        "ns": "1ns",
+    }
+
     _DTYPE_MAPPING = {
         # Direct mappings
         bigframes.dtypes.INT_DTYPE: pl.Int64(),
@@ -301,11 +318,76 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
             assert isinstance(op, string_ops.StrConcatOp)
             return pl.concat_str(l_input, r_input)

+        @compile_op.register(string_ops.StrContainsOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            assert isinstance(op, string_ops.StrContainsOp)
+            return input.str.contains(pattern=op.pat, literal=True)
+
+        @compile_op.register(string_ops.StrContainsRegexOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            assert isinstance(op, string_ops.StrContainsRegexOp)
+            return input.str.contains(pattern=op.pat, literal=False)
+
+        @compile_op.register(string_ops.StartsWithOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            assert isinstance(op, string_ops.StartsWithOp)
+            if len(op.pat) == 1:
+                return input.str.starts_with(op.pat[0])
+            else:
+                return pl.any_horizontal(
+                    *(input.str.starts_with(pat) for pat in op.pat)
+                )
+
+        @compile_op.register(string_ops.EndsWithOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            assert isinstance(op, string_ops.EndsWithOp)
+            if len(op.pat) == 1:
+                return input.str.ends_with(op.pat[0])
+            else:
+                return pl.any_horizontal(*(input.str.ends_with(pat) for pat in op.pat))
+
+        @compile_op.register(freq_ops.FloorDtOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            assert isinstance(op, freq_ops.FloorDtOp)
+            return input.dt.truncate(every=_FREQ_MAPPING[op.freq])
+
         @compile_op.register(dt_ops.StrftimeOp)
         def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
             assert isinstance(op, dt_ops.StrftimeOp)
             return input.dt.strftime(op.date_format)

+        @compile_op.register(date_ops.YearOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.year()
+
+        @compile_op.register(date_ops.QuarterOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.quarter()
+
+        @compile_op.register(date_ops.MonthOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.month()
+
+        @compile_op.register(date_ops.DayOfWeekOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.weekday() - 1
+
+        @compile_op.register(date_ops.DayOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.day()
+
+        @compile_op.register(date_ops.IsoYearOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.iso_year()
+
+        @compile_op.register(date_ops.IsoWeekOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.week()
+
+        @compile_op.register(date_ops.IsoDayOp)
+        def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+            return input.dt.weekday()
+
         @compile_op.register(dt_ops.ParseDatetimeOp)
         def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
             assert isinstance(op, dt_ops.ParseDatetimeOp)
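
Note on the date ops above: polars dt.weekday() is ISO-numbered (Monday = 1), while pandas dayofweek is zero-based (Monday = 0), hence the "- 1" in the DayOfWeekOp lowering and the direct mapping for IsoDayOp. A quick hedged check, assuming polars is installed:

import datetime
import polars as pl

s = pl.Series([datetime.date(2024, 1, 1)])  # 2024-01-01 is a Monday
print(s.dt.weekday().to_list())        # [1]: ISO day of week
print((s.dt.weekday() - 1).to_list())  # [0]: pandas-style dayofweek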
@@ -325,6 +407,36 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
             assert isinstance(op, json_ops.JSONDecode)
             return input.str.json_decode(_DTYPE_MAPPING[op.to_type])

+        @compile_op.register(arr_ops.ToArrayOp)
+        def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
+            return pl.concat_list(*inputs)
+
+        @compile_op.register(arr_ops.ArrayReduceOp)
+        def _(self, op: ops.ArrayReduceOp, input: pl.Expr) -> pl.Expr:
+            # TODO: Unify this with general aggregation compilation?
+            if isinstance(op.aggregation, agg_ops.MinOp):
+                return input.list.min()
+            if isinstance(op.aggregation, agg_ops.MaxOp):
+                return input.list.max()
+            if isinstance(op.aggregation, agg_ops.SumOp):
+                return input.list.sum()
+            if isinstance(op.aggregation, agg_ops.MeanOp):
+                return input.list.mean()
+            if isinstance(op.aggregation, agg_ops.CountOp):
+                return input.list.len()
+            if isinstance(op.aggregation, agg_ops.StdOp):
+                return input.list.std()
+            if isinstance(op.aggregation, agg_ops.VarOp):
+                return input.list.var()
+            if isinstance(op.aggregation, agg_ops.AnyOp):
+                return input.list.any()
+            if isinstance(op.aggregation, agg_ops.AllOp):
+                return input.list.all()
+            else:
+                raise NotImplementedError(
+                    f"Haven't implemented array aggregation: {op.aggregation}"
+                )
+
     @dataclasses.dataclass(frozen=True)
     class PolarsAggregateCompiler:
         scalar_compiler = PolarsExpressionCompiler()
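
A hedged sketch of the ToArrayOp/ArrayReduceOp lowering above, assuming polars is installed: a row-wise sum compiles to concat_list followed by list.sum():

import polars as pl

df = pl.DataFrame({"a": [1, 2], "b": [10, 20]})
out = df.select(pl.concat_list("a", "b").list.sum().alias("row_sum"))
print(out["row_sum"].to_list())  # [11, 22]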

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 58 additions & 0 deletions
@@ -14,6 +14,7 @@

 from __future__ import annotations

+import functools
 import typing

 import pandas as pd
@@ -292,6 +293,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
     return sge.Extract(this=sge.Identifier(this="DAYOFYEAR"), expression=expr.expr)


+@UNARY_OP_REGISTRATION.register(ops.EndsWithOp)
+def _(op: ops.EndsWithOp, expr: TypedExpr) -> sge.Expression:
+    if not op.pat:
+        return sge.false()
+
+    def to_endswith(pat: str) -> sge.Expression:
+        return sge.func("ENDS_WITH", expr.expr, sge.convert(pat))
+
+    conditions = [to_endswith(pat) for pat in op.pat]
+    return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions)
+
+
 @UNARY_OP_REGISTRATION.register(ops.exp_op)
 def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
     return sge.Case(
@@ -633,6 +646,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
     )


+@UNARY_OP_REGISTRATION.register(ops.StartsWithOp)
+def _(op: ops.StartsWithOp, expr: TypedExpr) -> sge.Expression:
+    if not op.pat:
+        return sge.false()
+
+    def to_startswith(pat: str) -> sge.Expression:
+        return sge.func("STARTS_WITH", expr.expr, sge.convert(pat))
+
+    conditions = [to_startswith(pat) for pat in op.pat]
+    return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions)
+
+
 @UNARY_OP_REGISTRATION.register(ops.StrStripOp)
 def _(op: ops.StrStripOp, expr: TypedExpr) -> sge.Expression:
     return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr)
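
A hedged sketch of the functools.reduce OR-folding shared by the EndsWithOp and StartsWithOp lowerings above, assuming sqlglot is installed:

import functools

import sqlglot.expressions as sge

conds = [
    sge.func("STARTS_WITH", sge.column("s"), sge.convert(pat))
    for pat in ("a", "b", "c")
]
pred = functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conds)
print(pred.sql(dialect="bigquery"))
# STARTS_WITH(s, 'a') OR STARTS_WITH(s, 'b') OR STARTS_WITH(s, 'c')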
@@ -656,6 +681,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
     )


+@UNARY_OP_REGISTRATION.register(ops.StringSplitOp)
+def _(op: ops.StringSplitOp, expr: TypedExpr) -> sge.Expression:
+    return sge.Split(this=expr.expr, expression=sge.convert(op.pat))
+
+
 @UNARY_OP_REGISTRATION.register(ops.StrGetOp)
 def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression:
     return sge.Substring(
@@ -808,3 +838,31 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
 @UNARY_OP_REGISTRATION.register(ops.year_op)
 def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
     return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr)
+
+
+@UNARY_OP_REGISTRATION.register(ops.ZfillOp)
+def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression:
+    return sge.Case(
+        ifs=[
+            sge.If(
+                this=sge.EQ(
+                    this=sge.Substring(
+                        this=expr.expr, start=sge.convert(1), length=sge.convert(1)
+                    ),
+                    expression=sge.convert("-"),
+                ),
+                true=sge.Concat(
+                    expressions=[
+                        sge.convert("-"),
+                        sge.func(
+                            "LPAD",
+                            sge.Substring(this=expr.expr, start=sge.convert(2)),
+                            sge.convert(op.width - 1),
+                            sge.convert("0"),
+                        ),
+                    ]
+                ),
+            )
+        ],
+        default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
+    )
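
For readability, a hedged pure-Python analogue of what the CASE expression above encodes: zero-pad to width while keeping a leading minus sign in front of the padding, matching str.zfill:

def zfill(s: str, width: int) -> str:
    if s[:1] == "-":  # negative: pad the digits, keep "-" in front
        return "-" + s[1:].rjust(width - 1, "0")
    return s.rjust(width, "0")

print(zfill("-12", 5), zfill("7", 3))  # -0012 007 (matches str.zfill)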
