Skip to content

Commit c7aa1af

Browse files
refactor: simplify filter and join nodes (#321)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent b8178b9 commit c7aa1af

File tree

11 files changed

+223
-156
lines changed

11 files changed

+223
-156
lines changed

bigframes/core/__init__.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@
1616
from dataclasses import dataclass
1717
import io
1818
import typing
19-
from typing import Iterable, Literal, Sequence
19+
from typing import Iterable, Sequence
2020

2121
import ibis.expr.types as ibis_types
2222
import pandas
2323

2424
import bigframes.core.compile as compiling
2525
import bigframes.core.expression as ex
2626
import bigframes.core.guid
27+
import bigframes.core.join_def as join_def
2728
import bigframes.core.nodes as nodes
2829
from bigframes.core.ordering import OrderingColumnReference
2930
import bigframes.core.ordering as orderings
3031
import bigframes.core.utils
3132
from bigframes.core.window_spec import WindowSpec
3233
import bigframes.dtypes
34+
import bigframes.operations as ops
3335
import bigframes.operations.aggregations as agg_ops
3436
import bigframes.session._io.bigquery
3537

@@ -114,13 +116,15 @@ def row_count(self) -> ArrayValue:
114116
return ArrayValue(nodes.RowCountNode(child=self.node))
115117

116118
# Operations
117-
def filter(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
119+
def filter_by_id(self, predicate_id: str, keep_null: bool = False) -> ArrayValue:
118120
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
119-
return ArrayValue(
120-
nodes.FilterNode(
121-
child=self.node, predicate_id=predicate_id, keep_null=keep_null
122-
)
123-
)
121+
predicate = ex.free_var(predicate_id)
122+
if keep_null:
123+
predicate = ops.fillna_op.as_expr(predicate, ex.const(True))
124+
return self.filter(predicate)
125+
126+
def filter(self, predicate: ex.Expression):
127+
return ArrayValue(nodes.FilterNode(child=self.node, predicate=predicate))
124128

125129
def order_by(self, by: Sequence[OrderingColumnReference]) -> ArrayValue:
126130
return ArrayValue(nodes.OrderByNode(child=self.node, by=tuple(by)))
@@ -356,26 +360,15 @@ def unpivot(
356360

357361
def join(
358362
self,
359-
self_column_ids: typing.Sequence[str],
360363
other: ArrayValue,
361-
other_column_ids: typing.Sequence[str],
362-
*,
363-
how: Literal[
364-
"inner",
365-
"left",
366-
"outer",
367-
"right",
368-
"cross",
369-
],
364+
join_def: join_def.JoinDefinition,
370365
allow_row_identity_join: bool = True,
371366
):
372367
return ArrayValue(
373368
nodes.JoinNode(
374369
left_child=self.node,
375370
right_child=other.node,
376-
left_column_ids=tuple(self_column_ids),
377-
right_column_ids=tuple(other_column_ids),
378-
how=how,
371+
join=join_def,
379372
allow_row_identity_join=allow_row_identity_join,
380373
)
381374
)

bigframes/core/blocks.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import bigframes.core.expression as ex
3939
import bigframes.core.guid as guid
4040
import bigframes.core.indexes as indexes
41-
import bigframes.core.joins.name_resolution as join_names
41+
import bigframes.core.join_def as join_defs
4242
import bigframes.core.ordering as ordering
4343
import bigframes.core.utils
4444
import bigframes.core.utils as utils
@@ -826,7 +826,7 @@ def assign_label(self, column_id: str, new_label: Label) -> Block:
826826

827827
def filter(self, column_id: str, keep_null: bool = False):
828828
return Block(
829-
self._expr.filter(column_id, keep_null),
829+
self._expr.filter_by_id(column_id, keep_null),
830830
index_columns=self.index_columns,
831831
column_labels=self.column_labels,
832832
index_labels=self.index.names,
@@ -1542,19 +1542,38 @@ def merge(
15421542
sort: bool,
15431543
suffixes: tuple[str, str] = ("_x", "_y"),
15441544
) -> Block:
1545-
joined_expr = self.expr.join(
1546-
left_join_ids,
1547-
other.expr,
1548-
right_join_ids,
1549-
how=how,
1550-
)
1551-
get_column_left, get_column_right = join_names.JOIN_NAME_REMAPPER(
1552-
self.expr.column_ids, other.expr.column_ids
1545+
left_mappings = [
1546+
join_defs.JoinColumnMapping(
1547+
source_table=join_defs.JoinSide.LEFT,
1548+
source_id=id,
1549+
destination_id=guid.generate_guid(),
1550+
)
1551+
for id in self.expr.column_ids
1552+
]
1553+
right_mappings = [
1554+
join_defs.JoinColumnMapping(
1555+
source_table=join_defs.JoinSide.RIGHT,
1556+
source_id=id,
1557+
destination_id=guid.generate_guid(),
1558+
)
1559+
for id in other.expr.column_ids
1560+
]
1561+
1562+
join_def = join_defs.JoinDefinition(
1563+
conditions=tuple(
1564+
join_defs.JoinCondition(left, right)
1565+
for left, right in zip(left_join_ids, right_join_ids)
1566+
),
1567+
mappings=(*left_mappings, *right_mappings),
1568+
type=how,
15531569
)
1570+
joined_expr = self.expr.join(other.expr, join_def=join_def)
15541571
result_columns = []
15551572
matching_join_labels = []
15561573

15571574
coalesced_ids = []
1575+
get_column_left = join_def.get_left_mapping()
1576+
get_column_right = join_def.get_right_mapping()
15581577
for left_id, right_id in zip(left_join_ids, right_join_ids):
15591578
coalesced_id = guid.generate_guid()
15601579
joined_expr = joined_expr.project_to_id(

bigframes/core/compile/compiled.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def _reduced_predicate(self) -> typing.Optional[ibis_types.BooleanValue]:
9696
)
9797

9898
@abc.abstractmethod
99-
def filter(self: T, predicate_id: str, keep_null: bool = False) -> T:
100-
"""Filter the table on a given expression, the predicate must be a boolean series aligned with the table expression."""
99+
def filter(self: T, predicate: ex.Expression) -> T:
100+
"""Filter the table on a given expression, the predicate must be a boolean expression."""
101101
...
102102

103103
@abc.abstractmethod
@@ -305,17 +305,9 @@ def _to_ibis_expr(
305305
table = table.filter(ibis.random() < ibis.literal(fraction))
306306
return table
307307

308-
def filter(self, predicate_id: str, keep_null: bool = False) -> UnorderedIR:
309-
condition = typing.cast(
310-
ibis_types.BooleanValue, self._get_ibis_column(predicate_id)
311-
)
312-
if keep_null:
313-
condition = typing.cast(
314-
ibis_types.BooleanValue,
315-
condition.fillna(
316-
typing.cast(ibis_types.BooleanScalar, ibis_types.literal(True))
317-
),
318-
)
308+
def filter(self, predicate: ex.Expression) -> UnorderedIR:
309+
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
310+
condition = op_compiler.compile_expression(predicate, bindings)
319311
return self._filter(condition)
320312

321313
def _filter(self, predicate_value: ibis_types.BooleanValue) -> UnorderedIR:
@@ -1140,17 +1132,9 @@ def _to_ibis_expr(
11401132
table = table.filter(ibis.random() < ibis.literal(fraction))
11411133
return table
11421134

1143-
def filter(self, predicate_id: str, keep_null: bool = False) -> OrderedIR:
1144-
condition = typing.cast(
1145-
ibis_types.BooleanValue, self._get_ibis_column(predicate_id)
1146-
)
1147-
if keep_null:
1148-
condition = typing.cast(
1149-
ibis_types.BooleanValue,
1150-
condition.fillna(
1151-
typing.cast(ibis_types.BooleanScalar, ibis_types.literal(True))
1152-
),
1153-
)
1135+
def filter(self, predicate: ex.Expression) -> OrderedIR:
1136+
bindings = {col: self._get_ibis_column(col) for col in self.column_ids}
1137+
condition = op_compiler.compile_expression(predicate, bindings)
11541138
return self._filter(condition)
11551139

11561140
def _filter(self, predicate_value: ibis_types.BooleanValue) -> OrderedIR:

bigframes/core/compile/compiler.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,18 @@ def compile_join(node: nodes.JoinNode, ordered: bool = True):
5959
left_ordered = compile_ordered(node.left_child)
6060
right_ordered = compile_ordered(node.right_child)
6161
return bigframes.core.compile.single_column.join_by_column_ordered(
62-
left_ordered,
63-
node.left_column_ids,
64-
right_ordered,
65-
node.right_column_ids,
66-
how=node.how,
62+
left=left_ordered,
63+
right=right_ordered,
64+
join=node.join,
6765
allow_row_identity_join=node.allow_row_identity_join,
6866
)
6967
else:
7068
left_unordered = compile_unordered(node.left_child)
7169
right_unordered = compile_unordered(node.right_child)
7270
return bigframes.core.compile.single_column.join_by_column_unordered(
73-
left_unordered,
74-
node.left_column_ids,
75-
right_unordered,
76-
node.right_column_ids,
77-
how=node.how,
71+
left=left_unordered,
72+
right=right_unordered,
73+
join=node.join,
7874
allow_row_identity_join=node.allow_row_identity_join,
7975
)
8076

@@ -113,7 +109,7 @@ def compile_promote_offsets(node: nodes.PromoteOffsetsNode, ordered: bool = True
113109

114110
@_compile_node.register
115111
def compile_filter(node: nodes.FilterNode, ordered: bool = True):
116-
return compile_node(node.child, ordered).filter(node.predicate_id, node.keep_null)
112+
return compile_node(node.child, ordered).filter(node.predicate)
117113

118114

119115
@_compile_node.register

bigframes/core/compile/row_identity.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import bigframes.constants as constants
2626
import bigframes.core.compile.compiled as compiled
27+
import bigframes.core.join_def as join_def
2728
import bigframes.core.joins as joining
2829
import bigframes.core.ordering as orderings
2930

@@ -33,11 +34,10 @@
3334
def join_by_row_identity_unordered(
3435
left: compiled.UnorderedIR,
3536
right: compiled.UnorderedIR,
36-
*,
37-
how: str,
37+
join_def: join_def.JoinDefinition,
3838
) -> compiled.UnorderedIR:
3939
"""Compute join when we are joining by row identity not a specific column."""
40-
if how not in SUPPORTED_ROW_IDENTITY_HOW:
40+
if join_def.type not in SUPPORTED_ROW_IDENTITY_HOW:
4141
raise NotImplementedError(
4242
f"Only how='outer','left','inner' currently supported. {constants.FEEDBACK_LINK}"
4343
)
@@ -60,17 +60,20 @@ def join_by_row_identity_unordered(
6060
combined_predicates = []
6161
if left_predicates or right_predicates:
6262
joined_predicates = _join_predicates(
63-
left_predicates, right_predicates, join_type=how
63+
left_predicates, right_predicates, join_type=join_def.type
6464
)
6565
combined_predicates = list(joined_predicates) # builder expects mutable list
6666

67-
left_mask = left_relative_predicates if how in ["right", "outer"] else None
68-
right_mask = right_relative_predicates if how in ["left", "outer"] else None
67+
left_mask = (
68+
left_relative_predicates if join_def.type in ["right", "outer"] else None
69+
)
70+
right_mask = (
71+
right_relative_predicates if join_def.type in ["left", "outer"] else None
72+
)
6973

7074
# Public mapping must use JOIN_NAME_REMAPPER to stay in sync with consumers of join result
71-
map_left_id, map_right_id = joining.JOIN_NAME_REMAPPER(
72-
left.column_ids, right.column_ids
73-
)
75+
map_left_id = join_def.get_left_mapping()
76+
map_right_id = join_def.get_right_mapping()
7477
joined_columns = [
7578
_mask_value(left._get_ibis_column(key), left_mask).name(map_left_id[key])
7679
for key in left.column_ids
@@ -90,11 +93,10 @@ def join_by_row_identity_unordered(
9093
def join_by_row_identity_ordered(
9194
left: compiled.OrderedIR,
9295
right: compiled.OrderedIR,
93-
*,
94-
how: str,
96+
join_def: join_def.JoinDefinition,
9597
) -> compiled.OrderedIR:
9698
"""Compute join when we are joining by row identity not a specific column."""
97-
if how not in SUPPORTED_ROW_IDENTITY_HOW:
99+
if join_def.type not in SUPPORTED_ROW_IDENTITY_HOW:
98100
raise NotImplementedError(
99101
f"Only how='outer','left','inner' currently supported. {constants.FEEDBACK_LINK}"
100102
)
@@ -117,17 +119,20 @@ def join_by_row_identity_ordered(
117119
combined_predicates = []
118120
if left_predicates or right_predicates:
119121
joined_predicates = _join_predicates(
120-
left_predicates, right_predicates, join_type=how
122+
left_predicates, right_predicates, join_type=join_def.type
121123
)
122124
combined_predicates = list(joined_predicates) # builder expects mutable list
123125

124-
left_mask = left_relative_predicates if how in ["right", "outer"] else None
125-
right_mask = right_relative_predicates if how in ["left", "outer"] else None
126+
left_mask = (
127+
left_relative_predicates if join_def.type in ["right", "outer"] else None
128+
)
129+
right_mask = (
130+
right_relative_predicates if join_def.type in ["left", "outer"] else None
131+
)
126132

127133
# Public mapping must use JOIN_NAME_REMAPPER to stay in sync with consumers of join result
128-
lpublicmapping, rpublicmapping = joining.JOIN_NAME_REMAPPER(
129-
left.column_ids, right.column_ids
130-
)
134+
lpublicmapping = join_def.get_left_mapping()
135+
rpublicmapping = join_def.get_right_mapping()
131136
lhiddenmapping, rhiddenmapping = joining.JoinNameRemapper(namespace="hidden")(
132137
left._hidden_column_ids, right._hidden_column_ids
133138
)

0 commit comments

Comments
 (0)