Skip to content

Commit 935af10

Browse files
fix: Fix scalar op lowering tree walk (#2029)
1 parent c0b54f0 commit 935af10

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

bigframes/core/expression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,11 @@ def is_identity(self) -> bool:
253253
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
254254
...
255255

256+
def bottom_up(self, t: Callable[[Expression], Expression]) -> Expression:
257+
expr = self.transform_children(lambda child: child.bottom_up(t))
258+
expr = t(expr)
259+
return expr
260+
256261
def walk(self) -> Generator[Expression, None, None]:
257262
yield self
258263
for child in self.children:

bigframes/core/rewrite/op_lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def lower_expr_step(expr: expression.Expression) -> expression.Expression:
4444
return maybe_rule.lower(expr)
4545
return expr
4646

47-
return lower_expr_step(expr.transform_children(lower_expr_step))
47+
return expr.bottom_up(lower_expr_step)
4848

4949
def lower_node(node: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
5050
if isinstance(

tests/system/small/engines/test_generic_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,18 @@ def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine):
423423
)
424424

425425
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
426+
427+
428+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
429+
def test_engines_isin_op_nested_filter(
430+
scalars_array_value: array_value.ArrayValue, engine
431+
):
432+
isin_clause = ops.IsInOp((1, 2, 3)).as_expr(expression.deref("int64_col"))
433+
filter_clause = ops.invert_op.as_expr(
434+
ops.or_op.as_expr(
435+
expression.deref("bool_col"), ops.invert_op.as_expr(isin_clause)
436+
)
437+
)
438+
arr = scalars_array_value.filter(filter_clause)
439+
440+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)