Skip to content

Commit 4da333e

Browse files
test: Enable floordiv local testing (#1856)
1 parent 80bac0f commit 4da333e

File tree

6 files changed

+168
-15
lines changed

6 files changed

+168
-15
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424
import bigframes.core
2525
from bigframes.core import identifiers, nodes, ordering, window_spec
26+
from bigframes.core.compile.polars import lowering
2627
import bigframes.core.expression as ex
2728
import bigframes.core.guid as guid
2829
import bigframes.core.rewrite
30+
import bigframes.core.rewrite.schema_binding
2931
import bigframes.dtypes
3032
import bigframes.operations as ops
3133
import bigframes.operations.aggregations as agg_ops
@@ -403,6 +405,8 @@ def compile(self, array_value: bigframes.core.ArrayValue) -> pl.LazyFrame:
403405
node = bigframes.core.rewrite.column_pruning(node)
404406
node = nodes.bottom_up(node, bigframes.core.rewrite.rewrite_slice)
405407
node = bigframes.core.rewrite.pull_out_window_order(node)
408+
node = bigframes.core.rewrite.schema_binding.bind_schema_to_tree(node)
409+
node = lowering.lower_ops_to_polars(node)
406410
return self.compile_node(node)
407411

408412
@functools.singledispatchmethod
bigframes/core/compile/polars/lowering.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from bigframes import dtypes
16+
from bigframes.core import bigframe_node, expression
17+
from bigframes.core.rewrite import op_lowering
18+
from bigframes.operations import numeric_ops
19+
import bigframes.operations as ops
20+
21+
# TODO: It would be more precise to have a separate op set for polars ops (where they diverge from the original ops).
22+
23+
24+
class LowerFloorDivRule(op_lowering.OpLoweringRule):
    """Lowering rule that guards floor division against a zero divisor.

    The floor-division expression is kept, but wrapped in a ``where`` that
    substitutes an explicitly computed value on rows where the divisor
    equals zero.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op class handled by this rule (``numeric_ops.FloorDivOp``)."""
        return numeric_ops.FloorDivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Wrap ``expr`` so a zero divisor yields a fixed substitute value.

        The substitute is ``INF * dividend`` when either operand has the
        float dtype, and ``0 * dividend`` otherwise.

        Args:
            expr: The floor-division expression; its first child is the
                dividend and its second child is the divisor.

        Returns:
            A ``where`` expression built from the substitute value, the
            "divisor == 0" condition, and the original ``expr``.
        """
        numerator = expr.children[0]
        denominator = expr.children[1]

        any_float_operand = any(
            operand.output_type == dtypes.FLOAT_DTYPE
            for operand in (numerator, denominator)
        )
        if any_float_operand:
            multiplier = expression.const(float("INF"))
        else:
            multiplier = expression.const(0)

        # Multiplying by the dividend (rather than using a bare constant)
        # keeps the substitute dependent on the dividend's value.
        value_when_zero = ops.mul_op.as_expr(multiplier, numerator)
        zero_divisor = ops.eq_op.as_expr(denominator, expression.const(0))
        # NOTE(review): assumes where_op takes (value, condition, fallback)
        # and yields `value` where `condition` holds — confirm.
        return ops.where_op.as_expr(value_when_zero, zero_divisor, expr)
41+
42+
43+
# Rules applied when lowering an expression tree for the polars compiler.
POLARS_LOWERING_RULES = (LowerFloorDivRule(),)


def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
    """Apply ``POLARS_LOWERING_RULES`` to the tree rooted at ``root``.

    Args:
        root: Root node of the tree to rewrite.

    Returns:
        The node returned by ``op_lowering.lower_ops`` for ``root``.
    """
    lowered_root = op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES)
    return lowered_root

bigframes/core/expression.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import functools
2020
import itertools
2121
import typing
22-
from typing import Generator, Mapping, TypeVar, Union
22+
from typing import Callable, Generator, Mapping, TypeVar, Union
2323

2424
import pandas as pd
2525

@@ -249,6 +249,10 @@ def is_identity(self) -> bool:
249249
"""True for identity operation that does not transform input."""
250250
return False
251251

252+
@abc.abstractmethod
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
    """Return an expression derived from this one using ``t``.

    Abstract: every concrete expression type supplies its own
    implementation.

    Args:
        t: Function mapping an expression to its replacement.
    """
255+
252256
def walk(self) -> Generator[Expression, None, None]:
253257
yield self
254258
for child in self.children:
@@ -311,6 +315,9 @@ def __eq__(self, other):
311315

312316
return self.value == other.value and self.dtype == other.dtype
313317

318+
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
    """Return this expression as-is; ``t`` is never invoked."""
    return self
320+
314321

315322
@dataclasses.dataclass(frozen=True)
316323
class UnboundVariableExpression(Expression):
@@ -362,6 +369,9 @@ def is_bijective(self) -> bool:
362369
def is_identity(self) -> bool:
363370
return True
364371

372+
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
    """Return this expression as-is; ``t`` is never invoked."""
    return self
374+
365375

366376
@dataclasses.dataclass(frozen=True)
367377
class DerefOp(Expression):
@@ -414,6 +424,9 @@ def is_bijective(self) -> bool:
414424
def is_identity(self) -> bool:
415425
return True
416426

427+
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
    """Return this expression as-is; ``t`` is never invoked."""
    return self
429+
417430

418431
@dataclasses.dataclass(frozen=True)
419432
class SchemaFieldRefExpression(Expression):
@@ -463,12 +476,15 @@ def is_bijective(self) -> bool:
463476
def is_identity(self) -> bool:
464477
return True
465478

479+
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
    """Return this expression as-is; ``t`` is never invoked."""
    return self
481+
466482

467483
@dataclasses.dataclass(frozen=True)
468484
class OpExpression(Expression):
469485
"""An expression representing a scalar operation applied to 1 or more argument sub-expressions."""
470486

471-
op: bigframes.operations.RowOp
487+
op: bigframes.operations.ScalarOp
472488
inputs: typing.Tuple[Expression, ...]
473489

474490
@property
@@ -553,6 +569,12 @@ def deterministic(self) -> bool:
553569
all(input.deterministic for input in self.inputs) and self.op.deterministic
554570
)
555571

572+
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
    """Apply ``t`` to each input and rebuild this expression if any changed.

    Args:
        t: Function mapping an input expression to its replacement.

    Returns:
        ``self`` when the transformed inputs compare equal to the current
        ones; otherwise a copy of this expression carrying the new inputs.
    """
    rewritten_inputs = tuple(map(t, self.inputs))
    if rewritten_inputs == self.inputs:
        # Nothing changed, so keep the existing object.
        return self
    return dataclasses.replace(self, inputs=rewritten_inputs)
577+
556578

557579
def bind_schema_fields(
558580
expr: Expression, field_by_id: Mapping[ids.ColumnId, field.Field]

bigframes/core/nodes.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,14 @@ def referenced_ids(self) -> COLUMN_SET:
10081008
def _node_expressions(self):
10091009
return (self.predicate,)
10101010

1011+
def transform_exprs(
    self, fn: Callable[[ex.Expression], ex.Expression]
) -> FilterNode:
    """Return a copy of this node with ``fn`` applied to its predicate.

    Args:
        fn: Function mapping an expression to its replacement.
    """
    rewritten_predicate = fn(self.predicate)
    return dataclasses.replace(self, predicate=rewritten_predicate)
1018+
10111019
def remap_vars(
10121020
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
10131021
) -> FilterNode:
@@ -1066,6 +1074,20 @@ def referenced_ids(self) -> COLUMN_SET:
10661074
def _node_expressions(self):
10671075
return tuple(map(lambda x: x.scalar_expression, self.by))
10681076

1077+
def transform_exprs(
    self, fn: Callable[[ex.Expression], ex.Expression]
) -> OrderByNode:
    """Return a copy of this node with ``fn`` applied to every ordering key.

    Each entry of ``by`` is copied with its ``scalar_expression`` replaced
    by ``fn(scalar_expression)``; the entry's other fields are carried over.

    Args:
        fn: Function mapping an expression to its replacement.
    """

    def rewrite_key(ordering_expr):
        rewritten_scalar = fn(ordering_expr.scalar_expression)
        return dataclasses.replace(ordering_expr, scalar_expression=rewritten_scalar)

    rewritten_by = cast(
        tuple[OrderingExpression, ...],
        tuple(map(rewrite_key, self.by)),
    )
    return dataclasses.replace(self, by=rewritten_by)
1090+
10691091
def remap_vars(
10701092
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
10711093
) -> OrderByNode:
@@ -1078,14 +1100,9 @@ def remap_refs(
10781100
itertools.chain.from_iterable(map(lambda x: x.referenced_columns, self.by))
10791101
)
10801102
ref_mapping = {id: ex.DerefOp(mappings[id]) for id in all_refs}
1081-
new_by = cast(
1082-
tuple[OrderingExpression, ...],
1083-
tuple(
1084-
by_expr.bind_refs(ref_mapping, allow_partial_bindings=True)
1085-
for by_expr in self.by
1086-
),
1103+
return self.transform_exprs(
1104+
lambda ex: ex.bind_refs(ref_mapping, allow_partial_bindings=True)
10871105
)
1088-
return dataclasses.replace(self, by=new_by)
10891106

10901107

10911108
@dataclasses.dataclass(frozen=True, eq=False)
@@ -1293,6 +1310,12 @@ def _node_expressions(self):
12931310
def additive_base(self) -> BigFrameNode:
12941311
return self.child
12951312

1313+
def transform_exprs(
    self, fn: Callable[[ex.Expression], ex.Expression]
) -> ProjectionNode:
    """Return a copy of this node with ``fn`` applied to each assigned expression.

    The identifier paired with each expression in ``assignments`` is kept
    unchanged.

    Args:
        fn: Function mapping an expression to its replacement.
    """
    rewritten_assignments = tuple(
        (fn(assigned_expr), target_id) for assigned_expr, target_id in self.assignments
    )
    return dataclasses.replace(self, assignments=rewritten_assignments)
1318+
12961319
def replace_additive_base(self, node: BigFrameNode) -> ProjectionNode:
    """Return a copy of this node whose ``child`` is ``node``."""
    with_new_child = dataclasses.replace(self, child=node)
    return with_new_child
12981321

bigframes/core/rewrite/op_lowering.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import abc
17+
from typing import Sequence
18+
19+
from bigframes.core import bigframe_node, expression, nodes
20+
import bigframes.operations as ops
21+
22+
23+
class OpLoweringRule(abc.ABC):
    """Interface for a rule that rewrites expressions of one scalar op class."""

    @property
    @abc.abstractmethod
    def op(self) -> type[ops.ScalarOp]:
        """The scalar op class whose expressions this rule rewrites."""

    @abc.abstractmethod
    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the replacement expression for ``expr``.

        Args:
            expr: An op expression to be rewritten by this rule.
        """
32+
33+
34+
def lower_ops(
    root: bigframe_node.BigFrameNode, rules: Sequence[OpLoweringRule]
) -> bigframe_node.BigFrameNode:
    """Rewrite scalar op expressions in a tree using the given lowering rules.

    Expressions held by projection, filter and order-by nodes are rewritten;
    every other node type is returned unchanged. Within an expression, the
    rewrite is post-order: descendants at every depth are lowered before
    their parent, so a matching op is rewritten however deeply it is nested.
    The output of a rule is not revisited, which lets a rule embed the
    original op in its result without looping.

    Args:
        root: Root node of the tree to rewrite.
        rules: Lowering rules, looked up by the class of each expression's
            op. If several rules name the same op class, the last one wins.

    Returns:
        The result of applying the rewrite through ``root.bottom_up``.
    """
    rules_by_op = {rule.op: rule for rule in rules}

    def lower_expr_step(expr: expression.Expression) -> expression.Expression:
        # Only op expressions can match a rule; everything else passes through.
        if isinstance(expr, expression.OpExpression):
            maybe_rule = rules_by_op.get(expr.op.__class__)
            if maybe_rule is not None:
                return maybe_rule.lower(expr)
        return expr

    def lower_expr(expr: expression.Expression) -> expression.Expression:
        # Recurse with lower_expr itself (not the single-step function) so
        # that grandchildren and deeper descendants are lowered too.
        return lower_expr_step(expr.transform_children(lower_expr))

    def lower_node(node: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
        if isinstance(
            node, (nodes.ProjectionNode, nodes.FilterNode, nodes.OrderByNode)
        ):
            return node.transform_exprs(lower_expr)
        return node

    return root.bottom_up(lower_node)

tests/unit/test_dataframe_polars.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,7 +2150,7 @@ def test_df_corrwith_series(scalars_dfs):
21502150
operator.sub,
21512151
operator.mul,
21522152
operator.truediv,
2153-
# operator.floordiv,
2153+
operator.floordiv,
21542154
operator.eq,
21552155
operator.ne,
21562156
operator.gt,
@@ -2163,7 +2163,7 @@ def test_df_corrwith_series(scalars_dfs):
21632163
"subtract",
21642164
"multiply",
21652165
"true_divide",
2166-
# "floor_divide",
2166+
"floor_divide",
21672167
"eq",
21682168
"ne",
21692169
"gt",
@@ -2217,8 +2217,8 @@ def test_scalar_binop_str_exception(scalars_dfs):
22172217
(lambda x, y: x.rmul(y, axis="index")),
22182218
(lambda x, y: x.truediv(y, axis="index")),
22192219
(lambda x, y: x.rtruediv(y, axis="index")),
2220-
# (lambda x, y: x.floordiv(y, axis="index")),
2221-
# (lambda x, y: x.floordiv(y, axis="index")),
2220+
(lambda x, y: x.floordiv(y, axis="index")),
2221+
(lambda x, y: x.floordiv(y, axis="index")),
22222222
(lambda x, y: x.gt(y, axis="index")),
22232223
(lambda x, y: x.ge(y, axis="index")),
22242224
(lambda x, y: x.lt(y, axis="index")),
@@ -2233,8 +2233,8 @@ def test_scalar_binop_str_exception(scalars_dfs):
22332233
"rmul",
22342234
"truediv",
22352235
"rtruediv",
2236-
# "floordiv",
2237-
# "rfloordiv",
2236+
"floordiv",
2237+
"rfloordiv",
22382238
"gt",
22392239
"ge",
22402240
"lt",

0 commit comments

Comments
 (0)