Skip to content

Commit db087b0

Browse files
perf: Improve isin performance (#1203)
1 parent 533db96 commit db087b0

File tree

9 files changed

+388
-57
lines changed

9 files changed

+388
-57
lines changed

bigframes/core/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ def project_window_op(
417417
output_name,
418418
)
419419

420+
def isin(
    self, other: ArrayValue, lcol: str, rcol: str
) -> typing.Tuple[ArrayValue, str]:
    """Build a membership test of ``lcol`` against ``rcol`` of ``other``.

    Wraps this value's node and ``other``'s node in an ``InNode`` whose
    indicator column gets a freshly generated unique id.

    Returns:
        A pair of the new ArrayValue and the name of the indicator column.
    """
    membership = nodes.InNode(
        left_child=self.node,
        right_child=other.node,
        left_col=ex.deref(lcol),
        right_col=ex.deref(rcol),
        indicator_col=ids.ColumnId.unique(),
    )
    indicator_name = membership.indicator_col.name
    return ArrayValue(membership), indicator_name
431+
420432
def relational_join(
421433
self,
422434
other: ArrayValue,

bigframes/core/blocks.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,23 +2036,15 @@ def isin(self, other: Block):
20362036
return block
20372037

20382038
def _isin_inner(self: Block, col: str, unique_values: core.ArrayValue) -> Block:
2039-
unique_values, const = unique_values.create_constant(
2040-
True, dtype=bigframes.dtypes.BOOL_DTYPE
2041-
)
2042-
expr, (l_map, r_map) = self._expr.relational_join(
2043-
unique_values, ((col, unique_values.column_ids[0]),), type="left"
2044-
)
2045-
expr, matches = expr.project_to_id(ops.notnull_op.as_expr(r_map[const]))
2039+
expr, matches = self._expr.isin(unique_values, col, unique_values.column_ids[0])
20462040

2047-
new_index_cols = tuple(l_map[idx_col] for idx_col in self.index_columns)
20482041
new_value_cols = tuple(
2049-
l_map[val_col] if val_col != col else matches
2050-
for val_col in self.value_columns
2042+
val_col if val_col != col else matches for val_col in self.value_columns
20512043
)
2052-
expr = expr.select_columns((*new_index_cols, *new_value_cols))
2044+
expr = expr.select_columns((*self.index_columns, *new_value_cols))
20532045
return Block(
20542046
expr,
2055-
index_columns=new_index_cols,
2047+
index_columns=self.index_columns,
20562048
column_labels=self.column_labels,
20572049
index_labels=self._index_labels,
20582050
)

bigframes/core/compile/compiler.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import bigframes.core.compile.concat as concat_impl
2929
import bigframes.core.compile.explode
3030
import bigframes.core.compile.ibis_types
31+
import bigframes.core.compile.isin
3132
import bigframes.core.compile.scalar_op_compiler
3233
import bigframes.core.compile.scalar_op_compiler as compile_scalar
3334
import bigframes.core.compile.schema_translator
@@ -128,6 +129,17 @@ def compile_join(self, node: nodes.JoinNode):
128129
conditions=condition_pairs,
129130
)
130131

132+
@_compile_node.register
def compile_isin(self, node: nodes.InNode):
    """Compile an ``InNode`` by compiling both children and delegating to ``isin_unordered``."""
    left_ir = self.compile_node(node.left_child)
    right_ir = self.compile_node(node.right_child)
    indicator_name = node.indicator_col.sql
    # (left key, right key) column names whose values are tested for membership.
    key_pair = (node.left_col.id.sql, node.right_col.id.sql)
    return bigframes.core.compile.isin.isin_unordered(
        left=left_ir,
        right=right_ir,
        indicator_col=indicator_name,
        conditions=key_pair,
    )
142+
131143
@_compile_node.register
132144
def compile_fromrange(self, node: nodes.FromRangeNode):
133145
# Both start and end are single elements and do not inherently have an order

bigframes/core/compile/isin.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers to compile isin (membership test) operations between ArrayValue objects."""
16+
17+
from __future__ import annotations
18+
19+
import itertools
20+
from typing import Tuple
21+
22+
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
23+
import bigframes_vendored.ibis.expr.types as ibis_types
24+
25+
import bigframes.core.compile.compiled as compiled
26+
27+
28+
def isin_unordered(
    left: compiled.UnorderedIR,
    right: compiled.UnorderedIR,
    indicator_col: str,
    conditions: Tuple[str, str],
) -> compiled.UnorderedIR:
    """Append a membership indicator column to the left expression.

    Both key columns are passed through ``value_to_join_key`` before the
    ``isin`` comparison is built.

    Arguments:
        left: Expression whose columns are all kept in the result.
        right: Expression supplying the values to test membership against.
        indicator_col: Name given to the new indicator column.
        conditions: Pair of (left column id, right column id) to compare.
    Returns:
        An expression over the left table with the left columns followed by
        the indicator column.
    """
    left_table = left._to_ibis_expr()
    right_table = right._to_ibis_expr()

    left_key = value_to_join_key(left_table[conditions[0]])
    right_key = value_to_join_key(right_table[conditions[1]])
    indicator = left_key.isin(right_key).name(indicator_col)

    passthrough = tuple(left_table[col.get_name()] for col in left.columns)
    return compiled.UnorderedIR(
        left_table,
        columns=(*passthrough, indicator),
    )
61+
62+
63+
def value_to_join_key(value: ibis_types.Value):
    """Convert a nullable value into a non-null string key.

    SQL will not match null keys together, but pandas does, so nulls are
    replaced with a sentinel string after casting the value to a string.
    """
    as_text = value if value.type().is_string() else value.cast(ibis_dtypes.str)
    # Use ``fill_null`` when the value exposes it, otherwise fall back to ``fillna``.
    fill = as_text.fill_null if hasattr(as_text, "fill_null") else as_text.fillna
    return fill(ibis_types.literal("$NULL_SENTINEL$"))

bigframes/core/nodes.py

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@ def explicitly_ordered(self) -> bool:
208208
"""
209209
...
210210

211+
@functools.cached_property
def height(self) -> int:
    """Length of the longest path from this node down to a childless node (0 if it has no children)."""
    # ``default=-1`` makes a node without children come out as 0 after the increment.
    return max((child.height for child in self.child_nodes), default=-1) + 1
216+
211217
@functools.cached_property
212218
def total_variables(self) -> int:
213219
return self.variables_introduced + sum(
@@ -284,6 +290,34 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
284290
return self.transform_children(lambda x: x.prune(used_cols))
285291

286292

293+
class AdditiveNode:
    """Interface for nodes that only add fields on top of a base node.

    Definition of additive - if you drop added_fields, you end up with the descendant.

    .. code-block:: text

        AdditiveNode (fields: a, b, c; added_fields: c)
            |
            |  additive_base
            V
        BigFrameNode (fields: a, b)

    """

    @property
    @abc.abstractmethod
    def added_fields(self) -> Tuple[Field, ...]:
        """Fields this node contributes beyond those of ``additive_base``."""
        ...

    @property
    @abc.abstractmethod
    def additive_base(self) -> BigFrameNode:
        """The node that remains once ``added_fields`` are dropped."""
        ...

    @abc.abstractmethod
    def replace_additive_base(self, node: BigFrameNode) -> BigFrameNode:
        """Return a copy of this node that uses ``node`` as its additive base."""
        # The parameter is named ``node`` (not after its type) to match the
        # concrete implementations of this method.
        ...
319+
320+
287321
@dataclasses.dataclass(frozen=True, eq=False)
288322
class UnaryNode(BigFrameNode):
289323
child: BigFrameNode
@@ -381,6 +415,106 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
381415
return self
382416

383417

418+
@dataclasses.dataclass(frozen=True, eq=False)
class InNode(BigFrameNode, AdditiveNode):
    """
    Special join type that returns only the rows of the left side, plus one
    added bool column indicating whether a match exists on the right side.

    Modelled separately from join node, as this operation preserves row identity.
    """

    # Left input: its fields are passed through unchanged.
    left_child: BigFrameNode
    # Right input: supplies the values to match against.
    right_child: BigFrameNode
    # Reference to the column of the left child being tested.
    left_col: ex.DerefOp
    # Reference to the column of the right child being matched against.
    right_col: ex.DerefOp
    # Id of the bool indicator column this node defines.
    indicator_col: bfet_ids.ColumnId

    def _validate(self):
        """Check that the two children do not share any column ids."""
        assert set(self.left_child.ids).isdisjoint(
            self.right_child.ids
        ), "Join ids collide"

    @property
    def row_preserving(self) -> bool:
        return False

    @property
    def non_local(self) -> bool:
        return True

    @property
    def child_nodes(self) -> typing.Sequence[BigFrameNode]:
        """Both inputs, left first."""
        return (self.left_child, self.right_child)

    @property
    def order_ambiguous(self) -> bool:
        return False

    @property
    def explicitly_ordered(self) -> bool:
        # Preserves left ordering always
        return True

    @property
    def added_fields(self) -> Tuple[Field, ...]:
        """The single bool indicator field defined by this node."""
        indicator = Field(self.indicator_col, bigframes.dtypes.BOOL_DTYPE)
        return (indicator,)

    @property
    def fields(self) -> Iterable[Field]:
        """Fields of the left child followed by the indicator field."""
        return itertools.chain(
            self.left_child.fields,
            self.added_fields,
        )

    @functools.cached_property
    def variables_introduced(self) -> int:
        """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
        return 1

    @property
    def joins(self) -> bool:
        return True

    @property
    def row_count(self) -> Optional[int]:
        """Same as the left child's row count."""
        return self.left_child.row_count

    @property
    def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
        """Only the indicator column id is defined here."""
        return (self.indicator_col,)

    @property
    def additive_base(self) -> BigFrameNode:
        """The left child is the base that the indicator column is added to."""
        return self.left_child

    def replace_additive_base(self, node: BigFrameNode):
        """Return a copy of this node with ``node`` as the left child."""
        return dataclasses.replace(self, left_child=node)

    def transform_children(
        self, t: Callable[[BigFrameNode], BigFrameNode]
    ) -> BigFrameNode:
        """Apply ``t`` to both children and return the resulting node."""
        new_left = t(self.left_child)
        new_right = t(self.right_child)
        candidate = dataclasses.replace(
            self, left_child=new_left, right_child=new_right
        )
        # reusing existing object speeds up eq, and saves a small amount of memory
        return self if self == candidate else candidate

    def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
        """Return this node unchanged; no pruning is applied."""
        return self

    def remap_vars(
        self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]
    ) -> BigFrameNode:
        """Rename the indicator column if ``mappings`` contains it."""
        new_indicator = mappings.get(self.indicator_col, self.indicator_col)
        return dataclasses.replace(self, indicator_col=new_indicator)

    def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
        """Remap both column references, allowing partial bindings."""
        # Kept as a single statement so the trailing ``type: ignore`` still covers it.
        return dataclasses.replace(self, left_col=self.left_col.remap_column_refs(mappings, allow_partial_bindings=True), right_col=self.right_col.remap_column_refs(mappings, allow_partial_bindings=True))  # type: ignore
516+
517+
384518
@dataclasses.dataclass(frozen=True, eq=False)
385519
class JoinNode(BigFrameNode):
386520
left_child: BigFrameNode
@@ -926,7 +1060,7 @@ class CachedTableNode(ReadTableNode):
9261060

9271061
# Unary nodes
9281062
@dataclasses.dataclass(frozen=True, eq=False)
929-
class PromoteOffsetsNode(UnaryNode):
1063+
class PromoteOffsetsNode(UnaryNode, AdditiveNode):
9301064
col_id: bigframes.core.identifiers.ColumnId
9311065

9321066
@property
@@ -959,6 +1093,13 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
9591093
def added_fields(self) -> Tuple[Field, ...]:
9601094
return (Field(self.col_id, bigframes.dtypes.INT_DTYPE),)
9611095

1096+
@property
def additive_base(self) -> BigFrameNode:
    """The child node, to which this node adds its column."""
    base = self.child
    return base
1099+
1100+
def replace_additive_base(self, node: BigFrameNode):
    """Return a copy of this node with ``node`` substituted as the child."""
    rebased = dataclasses.replace(self, child=node)
    return rebased
1102+
9621103
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
9631104
if self.col_id not in used_cols:
9641105
return self.child.prune(used_cols)
@@ -1171,7 +1312,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
11711312

11721313

11731314
@dataclasses.dataclass(frozen=True, eq=False)
1174-
class ProjectionNode(UnaryNode):
1315+
class ProjectionNode(UnaryNode, AdditiveNode):
11751316
"""Assigns new variables (without modifying existing ones)"""
11761317

11771318
assignments: typing.Tuple[
@@ -1212,6 +1353,13 @@ def row_count(self) -> Optional[int]:
12121353
def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
12131354
return tuple(id for _, id in self.assignments)
12141355

1356+
@property
def additive_base(self) -> BigFrameNode:
    """The child node, on top of which the assignments are added."""
    base = self.child
    return base
1359+
1360+
def replace_additive_base(self, node: BigFrameNode):
    """Return a copy of this node with ``node`` substituted as the child."""
    rebased = dataclasses.replace(self, child=node)
    return rebased
1362+
12151363
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
12161364
pruned_assignments = tuple(i for i in self.assignments if i[1] in used_cols)
12171365
if len(pruned_assignments) == 0:
@@ -1378,7 +1526,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]):
13781526

13791527

13801528
@dataclasses.dataclass(frozen=True, eq=False)
1381-
class WindowOpNode(UnaryNode):
1529+
class WindowOpNode(UnaryNode, AdditiveNode):
13821530
expression: ex.Aggregation
13831531
window_spec: window.WindowSpec
13841532
output_name: bigframes.core.identifiers.ColumnId
@@ -1438,6 +1586,13 @@ def inherits_order(self) -> bool:
14381586
) and self.expression.op.implicitly_inherits_order
14391587
return op_inherits_order or self.window_spec.row_bounded
14401588

1589+
@property
def additive_base(self) -> BigFrameNode:
    """The child node, on top of which the window output column is added."""
    base = self.child
    return base
1592+
1593+
def replace_additive_base(self, node: BigFrameNode):
    """Return a copy of this node with ``node`` substituted as the child."""
    rebased = dataclasses.replace(self, child=node)
    return rebased
1595+
14411596
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
14421597
if self.output_name not in used_cols:
14431598
return self.child.prune(used_cols)

0 commit comments

Comments
 (0)