
Commit 94bb62f

committed
Merge remote-tracking branch 'origin/log-adapter-session-scoped-6421369828766099756' into log-adapter-session-scoped-6421369828766099756
2 parents 6d49e47 + 3412370

13 files changed: +280 additions, −139 deletions

bigframes/core/block_transforms.py

Lines changed: 30 additions & 28 deletions
```diff
@@ -625,21 +625,7 @@ def skew(
     # counts, moment3 for each column
     aggregations = []
     for col in original_columns:
-        delta3_expr = _mean_delta_to_power(3, col)
-        count_agg = agg_expressions.UnaryAggregation(
-            agg_ops.count_op,
-            ex.deref(col),
-        )
-        moment3_agg = agg_expressions.UnaryAggregation(
-            agg_ops.mean_op,
-            delta3_expr,
-        )
-        variance_agg = agg_expressions.UnaryAggregation(
-            agg_ops.PopVarOp(),
-            ex.deref(col),
-        )
-        skew_expr = _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
-        aggregations.append(skew_expr)
+        aggregations.append(skew_expr(ex.deref(col)))
 
     block = block.aggregate(
         aggregations, grouping_column_ids, column_labels=column_labels
@@ -663,16 +649,7 @@ def kurt(
     # counts, moment4 for each column
     kurt_exprs = []
     for col in original_columns:
-        delta_4_expr = _mean_delta_to_power(4, col)
-        count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
-        moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
-        variance_agg = agg_expressions.UnaryAggregation(
-            agg_ops.PopVarOp(), ex.deref(col)
-        )
-
-        # Corresponds to order of aggregations in preceding loop
-        kurt_expr = _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
-        kurt_exprs.append(kurt_expr)
+        kurt_exprs.append(kurt_expr(ex.deref(col)))
 
     block = block.aggregate(
         kurt_exprs, grouping_column_ids, column_labels=column_labels
@@ -686,13 +663,38 @@ def kurt(
     return block
 
 
+def skew_expr(expr: ex.Expression) -> ex.Expression:
+    delta3_expr = _mean_delta_to_power(3, expr)
+    count_agg = agg_expressions.UnaryAggregation(
+        agg_ops.count_op,
+        expr,
+    )
+    moment3_agg = agg_expressions.UnaryAggregation(
+        agg_ops.mean_op,
+        delta3_expr,
+    )
+    variance_agg = agg_expressions.UnaryAggregation(
+        agg_ops.PopVarOp(),
+        expr,
+    )
+    return _skew_from_moments_and_count(count_agg, moment3_agg, variance_agg)
+
+
+def kurt_expr(expr: ex.Expression) -> ex.Expression:
+    delta_4_expr = _mean_delta_to_power(4, expr)
+    count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, expr)
+    moment4_agg = agg_expressions.UnaryAggregation(agg_ops.mean_op, delta_4_expr)
+    variance_agg = agg_expressions.UnaryAggregation(agg_ops.PopVarOp(), expr)
+    return _kurt_from_moments_and_count(count_agg, moment4_agg, variance_agg)
+
+
 def _mean_delta_to_power(
     n_power: int,
-    val_id: str,
+    col_expr: ex.Expression,
 ) -> ex.Expression:
     """Calculate (x-mean(x))^n. Useful for calculating moment statistics such as skew and kurtosis."""
-    mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref(val_id))
-    delta = ops.sub_op.as_expr(val_id, mean_expr)
+    mean_expr = agg_expressions.UnaryAggregation(agg_ops.mean_op, col_expr)
+    delta = ops.sub_op.as_expr(col_expr, mean_expr)
     return ops.pow_op.as_expr(delta, ex.const(n_power))
```
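Side note on the math these helpers assemble: `_skew_from_moments_and_count` is not shown in this diff, but the count/moment3/variance triple is exactly what the standard adjusted Fisher-Pearson sample skewness needs. A minimal NumPy sketch of that formula (an illustration of the presumed statistic, not the module's actual helper):

```python
import numpy as np

def sample_skew(x: np.ndarray) -> float:
    """G1 = n^2 / ((n-1)(n-2)) * m3 / s^3, built from the same three aggregates.

    m3 is the third central moment ("moment3_agg"), the population variance is
    "variance_agg" (PopVarOp), and n is "count_agg"; s is the sample std dev.
    """
    n = len(x)
    m3 = np.mean((x - x.mean()) ** 3)       # third central moment
    pop_var = np.mean((x - x.mean()) ** 2)  # population variance
    s = np.sqrt(pop_var * n / (n - 1))      # sample std from population variance
    return n**2 / ((n - 1) * (n - 2)) * m3 / s**3
```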
bigframes/core/bq_data.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -64,6 +64,21 @@ def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable:
             else tuple(table.clustering_fields),
         )
 
+    @staticmethod
+    def from_ref_and_schema(
+        table_ref: bq.TableReference,
+        schema: Sequence[bq.SchemaField],
+        cluster_cols: Optional[Sequence[str]] = None,
+    ) -> GbqTable:
+        return GbqTable(
+            project_id=table_ref.project,
+            dataset_id=table_ref.dataset_id,
+            table_id=table_ref.table_id,
+            physical_schema=tuple(schema),
+            is_physically_stored=True,
+            cluster_cols=tuple(cluster_cols) if cluster_cols else None,
+        )
+
     def get_table_ref(self) -> bq.TableReference:
         return bq.TableReference(
             bq.DatasetReference(self.project_id, self.dataset_id), self.table_id
```
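A quick usage sketch for the new constructor (the project/dataset/table names are made up; assumes only the google-cloud-bigquery types already used in the signature):

```python
from google.cloud import bigquery as bq

from bigframes.core.bq_data import GbqTable

# Build a GbqTable from just a reference and a schema, without fetching a
# full bq.Table from the API.
table_ref = bq.TableReference(
    bq.DatasetReference("my-project", "my_dataset"), "my_table"
)
schema = [bq.SchemaField("id", "INT64"), bq.SchemaField("payload", "JSON")]

gbq_table = GbqTable.from_ref_and_schema(table_ref, schema, cluster_cols=["id"])
assert gbq_table.get_table_ref() == table_ref  # round-trips to the same reference
```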

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -378,7 +378,7 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
         window_op = sge.Case(ifs=when_expressions, default=window_op)
 
     # TODO: check if we can directly window the expression.
-    result = child.window(
+    result = result.window(
        window_op=window_op,
        output_column_id=cdef.id.sql,
    )
```

(The fix chains each window onto the accumulated `result` rather than the original `child`, presumably so that columns added by earlier iterations of the surrounding loop are not dropped.)

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -93,6 +93,7 @@ def _construct_prompt(
     for elem in prompt_context:
         if elem is None:
             prompt.append(exprs[column_ref_idx].expr)
+            column_ref_idx += 1
         else:
             prompt.append(sge.Literal.string(elem))
 
```
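The bug here: without the increment, every `None` placeholder in `prompt_context` resolved to the first column reference. A self-contained sketch of the interleaving logic (plain strings stand in for the SQLGlot expressions):

```python
# Minimal sketch of the prompt-interleaving logic fixed above. None entries
# are placeholders that consume the next column reference in order.
def construct_prompt(prompt_context, column_refs):
    prompt = []
    column_ref_idx = 0
    for elem in prompt_context:
        if elem is None:
            prompt.append(column_refs[column_ref_idx])
            column_ref_idx += 1  # the missing increment: advance to the next column
        else:
            prompt.append(elem)
    return prompt

# Before the fix, both placeholders would have yielded "col_a".
assert construct_prompt(["Summarize:", None, "and", None], ["col_a", "col_b"]) == [
    "Summarize:", "col_a", "and", "col_b",
]
```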
bigframes/core/compile/sqlglot/expressions/json_ops.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -69,6 +69,11 @@ def _(expr: TypedExpr) -> sge.Expression:
     return sge.func("PARSE_JSON", expr.expr)
 
 
+@register_unary_op(ops.ToJSON)
+def _(expr: TypedExpr) -> sge.Expression:
+    return sge.func("TO_JSON", expr.expr)
+
+
 @register_unary_op(ops.ToJSONString)
 def _(expr: TypedExpr) -> sge.Expression:
     return sge.func("TO_JSON_STRING", expr.expr)
```
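For reference, `sge` here is `sqlglot.expressions`, so the new handler should render a BigQuery `TO_JSON` call. A quick standalone check (the column name is made up):

```python
import sqlglot.expressions as sge

# Same construction as the new ToJSON handler, rendered as BigQuery SQL.
expr = sge.func("TO_JSON", sge.column("payload"))
print(expr.sql(dialect="bigquery"))  # TO_JSON(payload)
```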

bigframes/core/expression.py

Lines changed: 1 addition & 53 deletions
```diff
@@ -15,12 +15,11 @@
 from __future__ import annotations
 
 import abc
-import collections
 import dataclasses
 import functools
 import itertools
 import typing
-from typing import Callable, Dict, Generator, Mapping, Tuple, TypeVar, Union
+from typing import Callable, Generator, Mapping, TypeVar, Union
 
 import pandas as pd
 
@@ -162,57 +161,6 @@ def walk(self) -> Generator[Expression, None, None]:
         for child in self.children:
             yield from child.children
 
-    def unique_nodes(
-        self: Expression,
-    ) -> Generator[Expression, None, None]:
-        """Walks the tree for unique nodes"""
-        seen = set()
-        stack: list[Expression] = [self]
-        while stack:
-            item = stack.pop()
-            if item not in seen:
-                yield item
-                seen.add(item)
-                stack.extend(item.children)
-
-    def iter_nodes_topo(
-        self: Expression,
-    ) -> Generator[Expression, None, None]:
-        """Returns nodes in reverse topological order, using Kahn's algorithm."""
-        child_to_parents: Dict[Expression, list[Expression]] = collections.defaultdict(
-            list
-        )
-        out_degree: Dict[Expression, int] = collections.defaultdict(int)
-
-        queue: collections.deque["Expression"] = collections.deque()
-        for node in list(self.unique_nodes()):
-            num_children = len(node.children)
-            out_degree[node] = num_children
-            if num_children == 0:
-                queue.append(node)
-            for child in node.children:
-                child_to_parents[child].append(node)
-
-        while queue:
-            item = queue.popleft()
-            yield item
-            parents = child_to_parents.get(item, [])
-            for parent in parents:
-                out_degree[parent] -= 1
-                if out_degree[parent] == 0:
-                    queue.append(parent)
-
-    def reduce_up(self, reduction: Callable[[Expression, Tuple[T, ...]], T]) -> T:
-        """Apply a bottom-up reduction to the tree."""
-        results: dict[Expression, T] = {}
-        for node in list(self.iter_nodes_topo()):
-            # child nodes have already been transformed
-            child_results = tuple(results[child] for child in node.children)
-            result = reduction(node, child_results)
-            results[node] = result
-
-        return results[self]
-
 
 @dataclasses.dataclass(frozen=True)
 class ScalarConstantExpression(Expression):
```

bigframes/core/expression_factoring.py

Lines changed: 80 additions & 11 deletions
```diff
@@ -18,7 +18,10 @@
 import functools
 import itertools
 from typing import (
+    Callable,
     cast,
+    Dict,
+    Generator,
     Hashable,
     Iterable,
     Iterator,
@@ -40,18 +43,72 @@
 
 _MAX_INLINE_COMPLEXITY = 10
 
+T = TypeVar("T")
+
+
+def unique_nodes(
+    roots: Sequence[expression.Expression],
+) -> Generator[expression.Expression, None, None]:
+    """Walks the tree for unique nodes"""
+    seen = set()
+    stack: list[expression.Expression] = list(roots)
+    while stack:
+        item = stack.pop()
+        if item not in seen:
+            yield item
+            seen.add(item)
+            stack.extend(item.children)
+
+
+def iter_nodes_topo(
+    roots: Sequence[expression.Expression],
+) -> Generator[expression.Expression, None, None]:
+    """Returns nodes in reverse topological order, using Kahn's algorithm."""
+    child_to_parents: Dict[
+        expression.Expression, list[expression.Expression]
+    ] = collections.defaultdict(list)
+    out_degree: Dict[expression.Expression, int] = collections.defaultdict(int)
+
+    queue: collections.deque[expression.Expression] = collections.deque()
+    for node in unique_nodes(roots):
+        num_children = len(node.children)
+        out_degree[node] = num_children
+        if num_children == 0:
+            queue.append(node)
+        for child in node.children:
+            child_to_parents[child].append(node)
+
+    while queue:
+        item = queue.popleft()
+        yield item
+        parents = child_to_parents.get(item, [])
+        for parent in parents:
+            out_degree[parent] -= 1
+            if out_degree[parent] == 0:
+                queue.append(parent)
+
+
+def reduce_up(
+    roots: Sequence[expression.Expression],
+    reduction: Callable[[expression.Expression, Tuple[T, ...]], T],
+) -> Tuple[T, ...]:
+    """Apply a bottom-up reduction to the forest."""
+    results: dict[expression.Expression, T] = {}
+    for node in list(iter_nodes_topo(roots)):
+        # child nodes have already been transformed
+        child_results = tuple(results[child] for child in node.children)
+        result = reduction(node, child_results)
+        results[node] = result
+
+    return tuple(results[root] for root in roots)
+
 
 def apply_col_exprs_to_plan(
     plan: nodes.BigFrameNode, col_exprs: Sequence[nodes.ColumnDef]
 ) -> nodes.BigFrameNode:
-    # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
     target_ids = tuple(named_expr.id for named_expr in col_exprs)
 
-    fragments = tuple(
-        itertools.chain.from_iterable(
-            fragmentize_expression(expr) for expr in col_exprs
-        )
-    )
+    fragments = fragmentize_expression(col_exprs)
    return push_into_tree(plan, fragments, target_ids)
 
 
@@ -101,14 +158,26 @@ class FactoredExpression:
     sub_exprs: Tuple[nodes.ColumnDef, ...]
 
 
-def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
+def fragmentize_expression(
+    roots: Sequence[nodes.ColumnDef],
+) -> Sequence[nodes.ColumnDef]:
     """
     The goal of this functions is to factor out an expression into multiple sub-expressions.
     """
-
-    factored_expr = root.expression.reduce_up(gather_fragments)
-    root_expr = nodes.ColumnDef(factored_expr.root_expr, root.id)
-    return (root_expr, *factored_expr.sub_exprs)
+    # TODO: Fragmentize a bit less aggressively
+    factored_exprs = reduce_up([root.expression for root in roots], gather_fragments)
+    root_exprs = (
+        nodes.ColumnDef(factored.root_expr, root.id)
+        for factored, root in zip(factored_exprs, roots)
+    )
+    return (
+        *root_exprs,
+        *dedupe(
+            itertools.chain.from_iterable(
+                factored_expr.sub_exprs for factored_expr in factored_exprs
+            )
+        ),
+    )
 
 
 @dataclasses.dataclass(frozen=True, eq=False)
```
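As a sanity check on the forest-level traversal (the helpers moved here from bigframes/core/expression.py, generalized from one tree to many roots), here is a tiny self-contained run of the same Kahn's-algorithm pattern with a bottom-up reduction. The `Node` class is a toy stand-in for `expression.Expression`:

```python
import collections
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Node:
    name: str
    children: Tuple["Node", ...] = ()

def iter_nodes_topo(roots):
    """Yield children before parents, via Kahn's algorithm (as in the diff above)."""
    child_to_parents = collections.defaultdict(list)
    out_degree = {}
    queue = collections.deque()
    seen, stack = set(), list(roots)
    while stack:  # unique_nodes, inlined
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        out_degree[node] = len(node.children)
        if not node.children:
            queue.append(node)
        for child in node.children:
            child_to_parents[child].append(node)
        stack.extend(node.children)
    while queue:
        node = queue.popleft()
        yield node
        for parent in child_to_parents.get(node, []):
            out_degree[parent] -= 1
            if out_degree[parent] == 0:
                queue.append(parent)

# The shared subtree x is visited once even though both roots reference it.
x = Node("x")
a, b = Node("a", (x,)), Node("b", (x, Node("y")))
sizes = {}
for node in iter_nodes_topo([a, b]):
    sizes[node] = 1 + sum(sizes[c] for c in node.children)  # bottom-up reduction
print(sizes[a], sizes[b])  # 2 3
```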

bigframes/core/local_data.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -486,7 +486,9 @@ def _append_offsets(
 ) -> Iterable[pa.RecordBatch]:
     offset = 0
     for batch in batches:
-        offsets = pa.array(range(offset, offset + batch.num_rows), type=pa.int64())
+        offsets = pa.array(
+            range(offset, offset + batch.num_rows), size=batch.num_rows, type=pa.int64()
+        )
         batch_w_offsets = pa.record_batch(
             [*batch.columns, offsets],
             schema=batch.schema.append(pa.field(offsets_col_name, pa.int64())),
```
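A standalone illustration of the offsets pattern, assuming only pyarrow (the column name is made up; `size` is pyarrow's preallocation hint for iterables):

```python
import pyarrow as pa

def append_offsets(batches, offsets_col_name="row_offset"):
    """Append a monotonically increasing int64 offset column to each batch."""
    offset = 0
    for batch in batches:
        offsets = pa.array(
            range(offset, offset + batch.num_rows), size=batch.num_rows, type=pa.int64()
        )
        yield pa.record_batch(
            [*batch.columns, offsets],
            schema=batch.schema.append(pa.field(offsets_col_name, pa.int64())),
        )
        offset += batch.num_rows

batch = pa.record_batch([pa.array(["a", "b", "c"])], names=["letters"])
out = list(append_offsets([batch, batch]))
print(out[1].column(1).to_pylist())  # [3, 4, 5]
```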
