Skip to content

Commit f7f686c

Browse files
feat: Add where, coalesce, fillna, casewhen, invert local impl (#1976)
1 parent 1b25c22 commit f7f686c

File tree

8 files changed

+187
-3
lines changed

8 files changed

+187
-3
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def compile_op(self, op: ops.ScalarOp, *args: pl.Expr) -> pl.Expr:
168168

169169
@compile_op.register(gen_ops.InvertOp)
170170
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
171-
return ~input
171+
return input.not_()
172172

173173
@compile_op.register(num_ops.AbsOp)
174174
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:

bigframes/core/compile/polars/lowering.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,18 @@
1414

1515
import dataclasses
1616

17+
import numpy as np
18+
1719
from bigframes import dtypes
1820
from bigframes.core import bigframe_node, expression
1921
from bigframes.core.rewrite import op_lowering
20-
from bigframes.operations import comparison_ops, datetime_ops, json_ops, numeric_ops
22+
from bigframes.operations import (
23+
comparison_ops,
24+
datetime_ops,
25+
generic_ops,
26+
json_ops,
27+
numeric_ops,
28+
)
2129
import bigframes.operations as ops
2230

2331
# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
@@ -288,6 +296,26 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
288296
return _lower_cast(expr.op, expr.inputs[0])
289297

290298

299+
def invert_bytes(byte_string):
300+
inverted_bytes = ~np.frombuffer(byte_string, dtype=np.uint8)
301+
return inverted_bytes.tobytes()
302+
303+
304+
class LowerInvertOp(op_lowering.OpLoweringRule):
305+
@property
306+
def op(self) -> type[ops.ScalarOp]:
307+
return generic_ops.InvertOp
308+
309+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
310+
assert isinstance(expr.op, generic_ops.InvertOp)
311+
arg = expr.children[0]
312+
if arg.output_type == dtypes.BYTES_DTYPE:
313+
return generic_ops.PyUdfOp(invert_bytes, dtypes.BYTES_DTYPE).as_expr(
314+
expr.inputs[0]
315+
)
316+
return expr
317+
318+
291319
def _coerce_comparables(
292320
expr1: expression.Expression,
293321
expr2: expression.Expression,
@@ -385,6 +413,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
385413
LowerFloorDivRule(),
386414
LowerModRule(),
387415
LowerAsTypeRule(),
416+
LowerInvertOp(),
388417
)
389418

390419

bigframes/core/compile/polars/operations/generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,14 @@ def isnull_op_impl(
4545
input: pl.Expr,
4646
) -> pl.Expr:
4747
return input.is_null()
48+
49+
50+
@polars_compiler.register_op(generic_ops.PyUdfOp)
51+
def py_udf_op_impl(
52+
compiler: polars_compiler.PolarsExpressionCompiler,
53+
op: generic_ops.PyUdfOp, # type: ignore
54+
input: pl.Expr,
55+
) -> pl.Expr:
56+
return input.map_elements(
57+
op.fn, return_dtype=polars_compiler._DTYPE_MAPPING[op._output_type]
58+
)

bigframes/operations/generic_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,15 @@ class SqlScalarOp(base_ops.NaryOp):
446446

447447
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
448448
return self._output_type
449+
450+
451+
@dataclasses.dataclass(frozen=True)
452+
class PyUdfOp(base_ops.NaryOp):
453+
"""Represents a local UDF."""
454+
455+
name: typing.ClassVar[str] = "py_udf"
456+
fn: typing.Callable
457+
_output_type: dtypes.ExpressionType
458+
459+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
460+
return self._output_type

bigframes/session/polars_executor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
numeric_ops.FloorDivOp,
5959
numeric_ops.ModOp,
6060
generic_ops.AsTypeOp,
61+
generic_ops.WhereOp,
62+
generic_ops.CoalesceOp,
63+
generic_ops.FillNaOp,
64+
generic_ops.CaseWhenOp,
65+
generic_ops.InvertOp,
6166
)
6267
_COMPATIBLE_AGG_OPS = (
6368
agg_ops.SizeOp,

tests/system/small/engines/test_generic_ops.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine)
5959
ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE),
6060
excluded_cols=["string_col"],
6161
)
62+
6263
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
6364

6465

@@ -73,6 +74,7 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue,
7374
for val in vals
7475
]
7576
)
77+
7678
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
7779

7880

@@ -83,6 +85,7 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin
8385
ops.AsTypeOp(to_type=bigframes.dtypes.FLOAT_DTYPE),
8486
excluded_cols=["string_col"],
8587
)
88+
8689
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
8790

8891

@@ -99,6 +102,7 @@ def test_engines_astype_string_float(
99102
for val in vals
100103
]
101104
)
105+
102106
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
103107

104108

@@ -107,6 +111,7 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine
107111
arr = apply_op(
108112
scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE)
109113
)
114+
110115
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
111116

112117

@@ -118,6 +123,7 @@ def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engi
118123
ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE),
119124
excluded_cols=["float64_col"],
120125
)
126+
121127
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
122128

123129

@@ -128,6 +134,7 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng
128134
ops.AsTypeOp(to_type=bigframes.dtypes.NUMERIC_DTYPE),
129135
excluded_cols=["string_col"],
130136
)
137+
131138
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
132139

133140

@@ -144,6 +151,7 @@ def test_engines_astype_string_numeric(
144151
for val in vals
145152
]
146153
)
154+
147155
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
148156

149157

@@ -154,6 +162,7 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine
154162
ops.AsTypeOp(to_type=bigframes.dtypes.DATE_DTYPE),
155163
excluded_cols=["string_col"],
156164
)
165+
157166
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
158167

159168

@@ -170,6 +179,7 @@ def test_engines_astype_string_date(
170179
for val in vals
171180
]
172181
)
182+
173183
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
174184

175185

@@ -180,6 +190,7 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en
180190
ops.AsTypeOp(to_type=bigframes.dtypes.DATETIME_DTYPE),
181191
excluded_cols=["string_col"],
182192
)
193+
183194
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
184195

185196

@@ -196,6 +207,7 @@ def test_engines_astype_string_datetime(
196207
for val in vals
197208
]
198209
)
210+
199211
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
200212

201213

@@ -206,6 +218,7 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e
206218
ops.AsTypeOp(to_type=bigframes.dtypes.TIMESTAMP_DTYPE),
207219
excluded_cols=["string_col"],
208220
)
221+
209222
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
210223

211224

@@ -226,6 +239,7 @@ def test_engines_astype_string_timestamp(
226239
for val in vals
227240
]
228241
)
242+
229243
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
230244

231245

@@ -236,6 +250,7 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine
236250
ops.AsTypeOp(to_type=bigframes.dtypes.TIME_DTYPE),
237251
excluded_cols=["string_col", "int64_col", "int64_too"],
238252
)
253+
239254
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
240255

241256

@@ -256,6 +271,7 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e
256271
),
257272
]
258273
arr, _ = scalars_array_value.compute_values(exprs)
274+
259275
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
260276

261277

@@ -265,4 +281,112 @@ def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, e
265281
scalars_array_value,
266282
ops.AsTypeOp(to_type=bigframes.dtypes.TIMEDELTA_DTYPE),
267283
)
284+
285+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
286+
287+
288+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
289+
def test_engines_where_op(scalars_array_value: array_value.ArrayValue, engine):
290+
arr, _ = scalars_array_value.compute_values(
291+
[
292+
ops.where_op.as_expr(
293+
expression.deref("int64_col"),
294+
expression.deref("bool_col"),
295+
expression.deref("float64_col"),
296+
)
297+
]
298+
)
299+
300+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
301+
302+
303+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
304+
def test_engines_coalesce_op(scalars_array_value: array_value.ArrayValue, engine):
305+
arr, _ = scalars_array_value.compute_values(
306+
[
307+
ops.coalesce_op.as_expr(
308+
expression.deref("int64_col"),
309+
expression.deref("float64_col"),
310+
)
311+
]
312+
)
313+
314+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
315+
316+
317+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
318+
def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine):
319+
arr, _ = scalars_array_value.compute_values(
320+
[
321+
ops.fillna_op.as_expr(
322+
expression.deref("int64_col"),
323+
expression.deref("float64_col"),
324+
)
325+
]
326+
)
327+
328+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
329+
330+
331+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
332+
def test_engines_casewhen_op_single_case(
333+
scalars_array_value: array_value.ArrayValue, engine
334+
):
335+
arr, _ = scalars_array_value.compute_values(
336+
[
337+
ops.case_when_op.as_expr(
338+
expression.deref("bool_col"),
339+
expression.deref("int64_col"),
340+
)
341+
]
342+
)
343+
344+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
345+
346+
347+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
348+
def test_engines_casewhen_op_double_case(
349+
scalars_array_value: array_value.ArrayValue, engine
350+
):
351+
arr, _ = scalars_array_value.compute_values(
352+
[
353+
ops.case_when_op.as_expr(
354+
ops.gt_op.as_expr(expression.deref("int64_col"), expression.const(3)),
355+
expression.deref("int64_col"),
356+
ops.lt_op.as_expr(expression.deref("int64_col"), expression.const(-3)),
357+
expression.deref("int64_too"),
358+
)
359+
]
360+
)
361+
362+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
363+
364+
365+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
366+
def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine):
367+
arr, _ = scalars_array_value.compute_values(
368+
[ops.isnull_op.as_expr(expression.deref("string_col"))]
369+
)
370+
371+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
372+
373+
374+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
375+
def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine):
376+
arr, _ = scalars_array_value.compute_values(
377+
[ops.notnull_op.as_expr(expression.deref("string_col"))]
378+
)
379+
380+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
381+
382+
383+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
384+
def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
385+
arr, _ = scalars_array_value.compute_values(
386+
[
387+
ops.invert_op.as_expr(expression.deref("bytes_col")),
388+
ops.invert_op.as_expr(expression.deref("bool_col")),
389+
]
390+
)
391+
268392
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

third_party/bigframes_vendored/ibis/expr/operations/numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ class Tan(TrigonometricUnary):
326326
class BitwiseNot(Unary):
327327
"""Bitwise NOT operation."""
328328

329-
arg: Integer
329+
arg: Value[dt.Integer | dt.Binary]
330330

331331
dtype = rlz.numeric_like("args", operator.invert)
332332

third_party/bigframes_vendored/ibis/expr/types/binary.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def hashbytes(
3232
"""
3333
return ops.HashBytes(self, how).to_expr()
3434

35+
def __invert__(self) -> BinaryValue:
36+
return ops.BitwiseNot(self).to_expr()
37+
3538

3639
@public
3740
class BinaryScalar(Scalar, BinaryValue):

0 commit comments

Comments
 (0)