Skip to content

Commit ebdcd02

Browse files
feat: Allow local arithmetic execution in hybrid engine (#1906)
1 parent c15cb8a commit ebdcd02

File tree

7 files changed

+520
-43
lines changed

7 files changed

+520
-43
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import bigframes.operations.comparison_ops as comp_ops
3636
import bigframes.operations.generic_ops as gen_ops
3737
import bigframes.operations.numeric_ops as num_ops
38+
import bigframes.operations.string_ops as string_ops
3839

3940
polars_installed = True
4041
if TYPE_CHECKING:
@@ -146,6 +147,14 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
146147
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
147148
return input.abs()
148149

150+
@compile_op.register(num_ops.FloorOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``FloorOp`` by calling ``floor()`` on the operand expression."""
    floored = input.floor()
    return floored
153+
154+
@compile_op.register(num_ops.CeilOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``CeilOp`` by calling ``ceil()`` on the operand expression."""
    ceiled = input.ceil()
    return ceiled
157+
149158
@compile_op.register(num_ops.PosOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``PosOp`` by applying unary plus to the operand expression."""
    # Unary ``+`` dispatches to the operand's ``__pos__``.
    return +input
@@ -182,10 +191,6 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
182191
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
183192
return l_input // r_input
184193

185-
@compile_op.register(num_ops.FloorDivOp)
186-
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
187-
return l_input // r_input
188-
189194
@compile_op.register(num_ops.ModOp)
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
    """Compile ``ModOp`` as ``l_input % r_input`` on the two operand expressions."""
    remainder = l_input % r_input
    return remainder
@@ -270,6 +275,11 @@ def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
270275
# eg. We want "True" instead of "true" for bool to string
271276
return input.cast(_DTYPE_MAPPING[op.to_type], strict=not op.safe)
272277

278+
@compile_op.register(string_ops.StrConcatOp)
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
    """Compile ``StrConcatOp`` by passing both operands to ``pl.concat_str``."""
    assert isinstance(op, string_ops.StrConcatOp)
    joined = pl.concat_str(l_input, r_input)
    return joined
282+
273283
@dataclasses.dataclass(frozen=True)
274284
class PolarsAggregateCompiler:
275285
scalar_compiler = PolarsExpressionCompiler()

bigframes/core/compile/polars/lowering.py

Lines changed: 248 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,259 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
3737
return expr.op.as_expr(larg, rarg)
3838

3939

40+
class LowerAddRule(op_lowering.OpLoweringRule):
    """Lowering rule that normalizes the operand types of ``AddOp`` expressions.

    * bool + bool is computed on integer casts and the sum is cast back to bool.
    * Two string-like operands are rewritten to ``strconcat_op``.
    * A single boolean operand is cast to integer.
    * When a date is added to a timedelta (in either order) the date side is
      cast to datetime first.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Return the op class (``AddOp``) this rule is keyed on."""
        return numeric_ops.AddOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return an equivalent expression with the operands coerced as described on the class."""
        assert isinstance(expr.op, numeric_ops.AddOp)
        left = expr.children[0]
        right = expr.children[1]
        left_type = left.output_type
        right_type = right.output_type
        to_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE)

        left_is_bool = left_type == dtypes.BOOL_DTYPE
        right_is_bool = right_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            # Add as integers, then bring the result back to bool.
            int_sum = expr.op.as_expr(to_int.as_expr(left), to_int.as_expr(right))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_sum)

        if dtypes.is_string_like(left_type) and dtypes.is_string_like(right_type):
            return ops.strconcat_op.as_expr(left, right)

        if left_is_bool:
            left = to_int.as_expr(left)
        if right_is_bool:
            right = to_int.as_expr(right)

        # The two date/timedelta pairings are mutually exclusive.
        if (
            left_type == dtypes.DATE_DTYPE
            and right_type == dtypes.TIMEDELTA_DTYPE
        ):
            left = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(left)
        elif (
            left_type == dtypes.TIMEDELTA_DTYPE
            and right_type == dtypes.DATE_DTYPE
        ):
            right = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(right)

        return expr.op.as_expr(left, right)
82+
83+
84+
class LowerSubRule(op_lowering.OpLoweringRule):
    """Lowering rule that normalizes the operand types of ``SubOp`` expressions.

    * bool - bool is computed on integer casts and the result is cast back to bool.
    * A single boolean operand is cast to integer.
    * A date left operand paired with a timedelta right operand is cast to
      datetime first.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Return the op class (``SubOp``) this rule is keyed on."""
        return numeric_ops.SubOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return an equivalent expression with the operands coerced as described on the class."""
        assert isinstance(expr.op, numeric_ops.SubOp)
        minuend = expr.children[0]
        subtrahend = expr.children[1]
        minuend_type = minuend.output_type
        subtrahend_type = subtrahend.output_type
        to_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE)

        minuend_is_bool = minuend_type == dtypes.BOOL_DTYPE
        subtrahend_is_bool = subtrahend_type == dtypes.BOOL_DTYPE

        if minuend_is_bool and subtrahend_is_bool:
            # Subtract as integers, then bring the result back to bool.
            int_difference = expr.op.as_expr(
                to_int.as_expr(minuend), to_int.as_expr(subtrahend)
            )
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_difference)

        if minuend_is_bool:
            minuend = to_int.as_expr(minuend)
        if subtrahend_is_bool:
            subtrahend = to_int.as_expr(subtrahend)

        if (
            minuend_type == dtypes.DATE_DTYPE
            and subtrahend_type == dtypes.TIMEDELTA_DTYPE
        ):
            minuend = ops.AsTypeOp(to_type=dtypes.DATETIME_DTYPE).as_expr(minuend)

        return expr.op.as_expr(minuend, subtrahend)
115+
116+
117+
@dataclasses.dataclass
class LowerMulRule(op_lowering.OpLoweringRule):
    """Lowering rule that normalizes the operand types of ``MulOp`` expressions.

    * bool * bool is computed on integer casts and the product is cast back to bool.
    * When exactly one operand is boolean, that operand is cast to integer.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Return the op class (``MulOp``) this rule is keyed on."""
        return numeric_ops.MulOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return an equivalent expression with boolean operands replaced by integer casts."""
        assert isinstance(expr.op, numeric_ops.MulOp)
        left = expr.children[0]
        right = expr.children[1]
        to_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE)

        left_is_bool = left.output_type == dtypes.BOOL_DTYPE
        right_is_bool = right.output_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            # Multiply as integers, then bring the result back to bool.
            int_product = expr.op.as_expr(to_int.as_expr(left), to_int.as_expr(right))
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_product)

        # Past this point at most one side is boolean.
        if left_is_bool:
            left = to_int.as_expr(left)
        elif right_is_bool:
            right = to_int.as_expr(right)

        return expr.op.as_expr(left, right)
149+
150+
151+
class LowerDivRule(op_lowering.OpLoweringRule):
    """Lowering rule that normalizes the operand types of ``DivOp`` expressions.

    * timedelta / numeric is rewritten as an integer floor division whose
      result is cast to integer and then back to timedelta.
    * bool / bool is computed on integer casts and the result is cast to bool.
    * Otherwise boolean operands are cast to integer and a NUMERIC or
      BIGNUMERIC dividend is cast to float before dividing.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Return the op class (``DivOp``) this rule is keyed on."""
        return numeric_ops.DivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return an equivalent expression with the operands coerced as described on the class."""
        assert isinstance(expr.op, numeric_ops.DivOp)

        numerator = expr.children[0]
        denominator = expr.children[1]
        to_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE)

        if numerator.output_type == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(
            denominator.output_type
        ):
            # Same construction as the timedelta // numeric case of floordiv.
            floored = ops.floordiv_op.as_expr(to_int.as_expr(numerator), denominator)
            floored_int = to_int.as_expr(floored)
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(floored_int)

        numerator_is_bool = numerator.output_type == dtypes.BOOL_DTYPE
        denominator_is_bool = denominator.output_type == dtypes.BOOL_DTYPE

        if numerator_is_bool and denominator_is_bool:
            int_quotient = expr.op.as_expr(
                to_int.as_expr(numerator), to_int.as_expr(denominator)
            )
            return ops.AsTypeOp(to_type=dtypes.BOOL_DTYPE).as_expr(int_quotient)

        # Polars division does not accept booleans, so those become integers;
        # a decimal (NUMERIC / BIGNUMERIC) dividend becomes float.
        if numerator_is_bool:
            numerator = to_int.as_expr(numerator)
        elif numerator.output_type in (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE):
            numerator = ops.AsTypeOp(to_type=dtypes.FLOAT_DTYPE).as_expr(numerator)
        if denominator_is_bool:
            denominator = to_int.as_expr(denominator)

        return numeric_ops.div_op.as_expr(numerator, denominator)
192+
193+
40194
class LowerFloorDivRule(op_lowering.OpLoweringRule):
    """Lowering rule that normalizes ``FloorDivOp`` expressions.

    * timedelta // timedelta is computed on integer casts and left as integer.
    * timedelta // numeric is computed on an integer cast of the dividend,
      then cast to integer and back to timedelta.
    * Boolean operands are cast to integer.
    * When the result type is not float, a zero divisor is guarded with a
      ``where`` that yields ``dividend * 0`` instead of dividing.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Return the op class (``FloorDivOp``) this rule is keyed on."""
        return numeric_ops.FloorDivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return an equivalent expression with the operands coerced as described on the class."""
        assert isinstance(expr.op, numeric_ops.FloorDivOp)

        numerator = expr.children[0]
        denominator = expr.children[1]
        to_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE)

        if numerator.output_type == dtypes.TIMEDELTA_DTYPE:
            if denominator.output_type == dtypes.TIMEDELTA_DTYPE:
                return expr.op.as_expr(
                    to_int.as_expr(numerator), to_int.as_expr(denominator)
                )
            if dtypes.is_numeric(denominator.output_type):
                # Fragile: a zero divisor breaks this, and the quotient must
                # fit back into an integer.
                quotient = expr.op.as_expr(to_int.as_expr(numerator), denominator)
                quotient_int = to_int.as_expr(quotient)
                return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(
                    quotient_int
                )

        if numerator.output_type == dtypes.BOOL_DTYPE:
            numerator = to_int.as_expr(numerator)
        if denominator.output_type == dtypes.BOOL_DTYPE:
            denominator = to_int.as_expr(denominator)

        if expr.output_type != dtypes.FLOAT_DTYPE:
            # Guard against a zero divisor. Multiplying the dividend by zero
            # (rather than using a literal 0) keeps its nulls in the result.
            zeroed_numerator = ops.mul_op.as_expr(numerator, expression.const(0))
            denominator_is_zero = ops.eq_op.as_expr(denominator, expression.const(0))
            plain_quotient = numeric_ops.floordiv_op.as_expr(numerator, denominator)
            return ops.where_op.as_expr(
                zeroed_numerator, denominator_is_zero, plain_quotient
            )

        return expr.op.as_expr(numerator, denominator)
239+
240+
241+
class LowerModRule(op_lowering.OpLoweringRule):
    """Lowering rule that normalizes ``ModOp`` expressions.

    * timedelta % timedelta is computed on integer casts, guarded against a
      zero divisor, and cast back to timedelta.
    * Boolean operands are cast to integer.
    * When the result type is integer, a zero divisor is guarded so the
      result is ``dividend * 0`` instead of the raw modulo.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Return the op class (``ModOp``) this rule is keyed on."""
        return numeric_ops.ModOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return an equivalent expression with coerced operands and zero-divisor handling."""
        assert isinstance(expr.op, numeric_ops.ModOp)
        larg, rarg = expr.children[0], expr.children[1]
        to_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE)

        if (
            larg.output_type == dtypes.TIMEDELTA_DTYPE
            and rarg.output_type == dtypes.TIMEDELTA_DTYPE
        ):
            guarded = self._guard_zero_divisor(
                expr.op, to_int.as_expr(larg), to_int.as_expr(rarg)
            )
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(guarded)

        if larg.output_type == dtypes.BOOL_DTYPE:
            larg = to_int.as_expr(larg)
        if rarg.output_type == dtypes.BOOL_DTYPE:
            rarg = to_int.as_expr(rarg)

        # ``expr`` is never reassigned, so its output type is still that of
        # the expression being lowered.
        if expr.output_type == dtypes.INT_DTYPE:
            return self._guard_zero_divisor(expr.op, larg, rarg)
        return expr.op.as_expr(larg, rarg)

    @staticmethod
    def _guard_zero_divisor(
        mod_op: ops.ScalarOp,
        dividend: expression.Expression,
        divisor: expression.Expression,
    ) -> expression.Expression:
        """Build ``dividend % divisor`` with a zero divisor mapped to ``dividend * 0``.

        The zero branch multiplies the *dividend* by zero so that a null
        dividend stays null. Multiplying the divisor cannot do that, because
        on that branch the divisor is known to be zero. A null divisor still
        gives null through the plain modulo branch.

        NOTE(review): assumes ``where_op.as_expr(a, cond, b)`` yields ``a``
        where ``cond`` is true and ``b`` otherwise, including when ``cond`` is
        null — confirm against the ``where_op`` compilation.
        """
        return ops.where_op.as_expr(
            ops.mul_op.as_expr(dividend, expression.const(0)),
            ops.eq_op.as_expr(divisor, expression.const(0)),
            mod_op.as_expr(dividend, divisor),
        )
57279

58280

59-
def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression):
281+
def _coerce_comparables(
282+
expr1: expression.Expression,
283+
expr2: expression.Expression,
284+
*,
285+
bools_only: bool = False
286+
):
287+
if bools_only:
288+
if (
289+
expr1.output_type != dtypes.BOOL_DTYPE
290+
and expr2.output_type != dtypes.BOOL_DTYPE
291+
):
292+
return expr1, expr2
60293

61294
target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)
62295
if expr1.output_type != target_type:
@@ -90,7 +323,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90323

91324
# All lowering rules defined for Polars: the comparison rules followed by one
# instance of each arithmetic rule, in this order.
POLARS_LOWERING_RULES = (
    *LOWER_COMPARISONS,
    *(
        rule_type()
        for rule_type in (
            LowerAddRule,
            LowerSubRule,
            LowerMulRule,
            LowerDivRule,
            LowerFloorDivRule,
            LowerModRule,
        )
    ),
)
95333

96334

0 commit comments

Comments
 (0)