Skip to content

Commit 63205f2

Browse files
refactor: Refactor polars scalar op compiler (#1807)
1 parent b586746 commit 63205f2

File tree

13 files changed

+317
-166
lines changed

13 files changed

+317
-166
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 137 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,10 @@
2929
import bigframes.dtypes
3030
import bigframes.operations as ops
3131
import bigframes.operations.aggregations as agg_ops
32+
import bigframes.operations.bool_ops as bool_ops
33+
import bigframes.operations.comparison_ops as comp_ops
34+
import bigframes.operations.generic_ops as gen_ops
35+
import bigframes.operations.numeric_ops as num_ops
3236

3337
polars_installed = True
3438
if TYPE_CHECKING:
@@ -123,84 +127,146 @@ def _(
123127
self,
124128
expression: ex.OpExpression,
125129
) -> pl.Expr:
126-
# TODO: Complete the implementation, convert to hash dispatch
130+
# TODO: Complete the implementation
127131
op = expression.op
128132
args = tuple(map(self.compile_expression, expression.inputs))
129-
if isinstance(op, ops.invert_op.__class__):
130-
return ~args[0]
131-
if isinstance(op, ops.and_op.__class__):
132-
return args[0] & args[1]
133-
if isinstance(op, ops.or_op.__class__):
134-
return args[0] | args[1]
135-
if isinstance(op, ops.add_op.__class__):
136-
return args[0] + args[1]
137-
if isinstance(op, ops.sub_op.__class__):
138-
return args[0] - args[1]
139-
if isinstance(op, ops.mul_op.__class__):
140-
return args[0] * args[1]
141-
if isinstance(op, ops.div_op.__class__):
142-
return args[0] / args[1]
143-
if isinstance(op, ops.floordiv_op.__class__):
144-
# TODO: Handle int // 0
145-
return args[0] // args[1]
146-
if isinstance(op, (ops.pow_op.__class__, ops.unsafe_pow_op.__class__)):
147-
return args[0] ** args[1]
148-
if isinstance(op, ops.abs_op.__class__):
149-
return args[0].abs()
150-
if isinstance(op, ops.neg_op.__class__):
151-
return args[0].neg()
152-
if isinstance(op, ops.pos_op.__class__):
153-
return args[0]
154-
if isinstance(op, ops.ge_op.__class__):
155-
return args[0] >= args[1]
156-
if isinstance(op, ops.gt_op.__class__):
157-
return args[0] > args[1]
158-
if isinstance(op, ops.le_op.__class__):
159-
return args[0] <= args[1]
160-
if isinstance(op, ops.lt_op.__class__):
161-
return args[0] < args[1]
162-
if isinstance(op, ops.eq_op.__class__):
163-
return args[0].eq(args[1])
164-
if isinstance(op, ops.eq_null_match_op.__class__):
165-
return args[0].eq_missing(args[1])
166-
if isinstance(op, ops.ne_op.__class__):
167-
return args[0].ne(args[1])
168-
if isinstance(op, ops.IsInOp):
169-
# TODO: Filter out types that can't be coerced to right type
170-
if op.match_nulls or not any(map(pd.isna, op.values)):
171-
# newer polars version have nulls_equal arg
172-
return args[0].is_in(op.values)
173-
else:
174-
return args[0].is_in(op.values) or args[0].is_null()
175-
if isinstance(op, ops.mod_op.__class__):
176-
return args[0] % args[1]
177-
if isinstance(op, ops.coalesce_op.__class__):
178-
return pl.coalesce(*args)
179-
if isinstance(op, ops.fillna_op.__class__):
180-
return pl.coalesce(*args)
181-
if isinstance(op, ops.isnull_op.__class__):
182-
return args[0].is_null()
183-
if isinstance(op, ops.notnull_op.__class__):
184-
return args[0].is_not_null()
185-
if isinstance(op, ops.CaseWhenOp):
186-
expr = pl.when(args[0]).then(args[1])
187-
for pred, result in zip(args[2::2], args[3::2]):
188-
expr = expr.when(pred).then(result) # type: ignore
189-
return expr
190-
if isinstance(op, ops.where_op.__class__):
191-
original, condition, otherwise = args
192-
return pl.when(condition).then(original).otherwise(otherwise)
193-
if isinstance(op, ops.AsTypeOp):
194-
return self.astype(args[0], op.to_type, safe=op.safe)
133+
return self.compile_op(op, *args)
195134

135+
@functools.singledispatchmethod
136+
def compile_op(self, op: ops.ScalarOp, *args: pl.Expr) -> pl.Expr:
196137
raise NotImplementedError(f"Polars compiler hasn't implemented {op}")
197138

198-
def astype(
199-
self, col: pl.Expr, dtype: bigframes.dtypes.Dtype, safe: bool
139+
@compile_op.register(gen_ops.InvertOp)
140+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
141+
return ~input
142+
143+
@compile_op.register(num_ops.AbsOp)
144+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
145+
return input.abs()
146+
147+
@compile_op.register(num_ops.PosOp)
148+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
149+
return input.__pos__()
150+
151+
@compile_op.register(num_ops.NegOp)
152+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
153+
return input.__neg__()
154+
155+
@compile_op.register(bool_ops.AndOp)
156+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
157+
return l_input & r_input
158+
159+
@compile_op.register(bool_ops.OrOp)
160+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
161+
return l_input | r_input
162+
163+
@compile_op.register(num_ops.AddOp)
164+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
165+
return l_input + r_input
166+
167+
@compile_op.register(num_ops.SubOp)
168+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
169+
return l_input - r_input
170+
171+
@compile_op.register(num_ops.MulOp)
172+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
173+
return l_input * r_input
174+
175+
@compile_op.register(num_ops.DivOp)
176+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
177+
return l_input / r_input
178+
179+
@compile_op.register(num_ops.FloorDivOp)
180+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
181+
return l_input // r_input
182+
183+
@compile_op.register(num_ops.FloorDivOp)
184+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
185+
return l_input // r_input
186+
187+
@compile_op.register(num_ops.ModOp)
188+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
189+
return l_input % r_input
190+
191+
@compile_op.register(num_ops.PowOp)
192+
@compile_op.register(num_ops.UnsafePowOp)
193+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
194+
return l_input**r_input
195+
196+
@compile_op.register(comp_ops.EqOp)
197+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
198+
return l_input.eq(r_input)
199+
200+
@compile_op.register(comp_ops.EqNullsMatchOp)
201+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
202+
return l_input.eq_missing(r_input)
203+
204+
@compile_op.register(comp_ops.NeOp)
205+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
206+
return l_input.ne(r_input)
207+
208+
@compile_op.register(comp_ops.GtOp)
209+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
210+
return l_input > r_input
211+
212+
@compile_op.register(comp_ops.GeOp)
213+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
214+
return l_input >= r_input
215+
216+
@compile_op.register(comp_ops.LtOp)
217+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
218+
return l_input < r_input
219+
220+
@compile_op.register(comp_ops.LeOp)
221+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
222+
return l_input <= r_input
223+
224+
@compile_op.register(gen_ops.IsInOp)
225+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
226+
# TODO: Filter out types that can't be coerced to right type
227+
assert isinstance(op, gen_ops.IsInOp)
228+
if op.match_nulls or not any(map(pd.isna, op.values)):
229+
# newer polars version have nulls_equal arg
230+
return input.is_in(op.values)
231+
else:
232+
return input.is_in(op.values) or input.is_null()
233+
234+
@compile_op.register(gen_ops.IsNullOp)
235+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
236+
return input.is_null()
237+
238+
@compile_op.register(gen_ops.NotNullOp)
239+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
240+
return input.is_not_null()
241+
242+
@compile_op.register(gen_ops.FillNaOp)
243+
@compile_op.register(gen_ops.CoalesceOp)
244+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
245+
return pl.coalesce(l_input, r_input)
246+
247+
@compile_op.register(gen_ops.CaseWhenOp)
248+
def _(self, op: ops.ScalarOp, *inputs: pl.Expr) -> pl.Expr:
249+
expr = pl.when(inputs[0]).then(inputs[1])
250+
for pred, result in zip(inputs[2::2], inputs[3::2]):
251+
expr = expr.when(pred).then(result) # type: ignore
252+
return expr
253+
254+
@compile_op.register(gen_ops.WhereOp)
255+
def _(
256+
self,
257+
op: ops.ScalarOp,
258+
original: pl.Expr,
259+
condition: pl.Expr,
260+
otherwise: pl.Expr,
200261
) -> pl.Expr:
262+
return pl.when(condition).then(original).otherwise(otherwise)
263+
264+
@compile_op.register(gen_ops.AsTypeOp)
265+
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
266+
assert isinstance(op, gen_ops.AsTypeOp)
201267
# TODO: Polars casting works differently, need to lower instead to specific conversion ops.
202-
# eg. We want "True" instead of "true" for bool to string.
203-
return col.cast(_DTYPE_MAPPING[dtype], strict=not safe)
268+
# eg. We want "True" instead of "true" for bool to string.
269+
return input.cast(_DTYPE_MAPPING[op.to_type], strict=not op.safe)
204270

205271
@dataclasses.dataclass(frozen=True)
206272
class PolarsAggregateCompiler:

bigframes/operations/base_ops.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,9 @@ def _convert_expr_input(
180180

181181

182182
# Operation Factories
183-
def create_unary_op(name: str, type_signature: op_typing.UnaryTypeSignature) -> UnaryOp:
183+
def create_unary_op(
184+
name: str, type_signature: op_typing.UnaryTypeSignature
185+
) -> type[UnaryOp]:
184186
return dataclasses.make_dataclass(
185187
name,
186188
[
@@ -189,12 +191,12 @@ def create_unary_op(name: str, type_signature: op_typing.UnaryTypeSignature) ->
189191
],
190192
bases=(UnaryOp,),
191193
frozen=True,
192-
)()
194+
)
193195

194196

195197
def create_binary_op(
196198
name: str, type_signature: op_typing.BinaryTypeSignature
197-
) -> BinaryOp:
199+
) -> type[BinaryOp]:
198200
return dataclasses.make_dataclass(
199201
name,
200202
[
@@ -203,4 +205,4 @@ def create_binary_op(
203205
],
204206
bases=(BinaryOp,),
205207
frozen=True,
206-
)()
208+
)

bigframes/operations/blob_ops.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,10 @@
1919
from bigframes.operations import base_ops
2020
import bigframes.operations.type as op_typing
2121

22-
obj_fetch_metadata_op = base_ops.create_unary_op(
22+
ObjFetchMetadataOp = base_ops.create_unary_op(
2323
name="obj_fetch_metadata", type_signature=op_typing.BLOB_TRANSFORM
2424
)
25+
obj_fetch_metadata_op = ObjFetchMetadataOp()
2526

2627

2728
@dataclasses.dataclass(frozen=True)

bigframes/operations/bool_ops.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,11 @@
1616
from bigframes.operations import base_ops
1717
import bigframes.operations.type as op_typing
1818

19-
and_op = base_ops.create_binary_op(name="and", type_signature=op_typing.LOGICAL)
19+
AndOp = base_ops.create_binary_op(name="and", type_signature=op_typing.LOGICAL)
20+
and_op = AndOp()
2021

21-
or_op = base_ops.create_binary_op(name="or", type_signature=op_typing.LOGICAL)
22+
OrOp = base_ops.create_binary_op(name="or", type_signature=op_typing.LOGICAL)
23+
or_op = OrOp()
2224

23-
xor_op = base_ops.create_binary_op(name="xor", type_signature=op_typing.LOGICAL)
25+
XorOp = base_ops.create_binary_op(name="xor", type_signature=op_typing.LOGICAL)
26+
xor_op = XorOp()

bigframes/operations/comparison_ops.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -16,18 +16,25 @@
1616
from bigframes.operations import base_ops
1717
import bigframes.operations.type as op_typing
1818

19-
eq_op = base_ops.create_binary_op(name="eq", type_signature=op_typing.COMPARISON)
19+
EqOp = base_ops.create_binary_op(name="eq", type_signature=op_typing.COMPARISON)
20+
eq_op = EqOp()
2021

21-
eq_null_match_op = base_ops.create_binary_op(
22+
EqNullsMatchOp = base_ops.create_binary_op(
2223
name="eq_nulls_match", type_signature=op_typing.COMPARISON
2324
)
25+
eq_null_match_op = EqNullsMatchOp()
2426

25-
ne_op = base_ops.create_binary_op(name="ne", type_signature=op_typing.COMPARISON)
27+
NeOp = base_ops.create_binary_op(name="ne", type_signature=op_typing.COMPARISON)
28+
ne_op = NeOp()
2629

27-
lt_op = base_ops.create_binary_op(name="lt", type_signature=op_typing.COMPARISON)
30+
LtOp = base_ops.create_binary_op(name="lt", type_signature=op_typing.COMPARISON)
31+
lt_op = LtOp()
2832

29-
gt_op = base_ops.create_binary_op(name="gt", type_signature=op_typing.COMPARISON)
33+
GtOp = base_ops.create_binary_op(name="gt", type_signature=op_typing.COMPARISON)
34+
gt_op = GtOp()
3035

31-
le_op = base_ops.create_binary_op(name="le", type_signature=op_typing.COMPARISON)
36+
LeOp = base_ops.create_binary_op(name="le", type_signature=op_typing.COMPARISON)
37+
le_op = LeOp()
3238

33-
ge_op = base_ops.create_binary_op(name="ge", type_signature=op_typing.COMPARISON)
39+
GeOp = base_ops.create_binary_op(name="ge", type_signature=op_typing.COMPARISON)
40+
ge_op = GeOp()

bigframes/operations/date_ops.py

Lines changed: 18 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -19,49 +19,58 @@
1919
from bigframes.operations import base_ops
2020
import bigframes.operations.type as op_typing
2121

22-
day_op = base_ops.create_unary_op(
22+
DayOp = base_ops.create_unary_op(
2323
name="day",
2424
type_signature=op_typing.DATELIKE_ACCESSOR,
2525
)
26+
day_op = DayOp()
2627

27-
month_op = base_ops.create_unary_op(
28+
MonthOp = base_ops.create_unary_op(
2829
name="month",
2930
type_signature=op_typing.DATELIKE_ACCESSOR,
3031
)
32+
month_op = MonthOp()
3133

32-
year_op = base_ops.create_unary_op(
34+
YearOp = base_ops.create_unary_op(
3335
name="year",
3436
type_signature=op_typing.DATELIKE_ACCESSOR,
3537
)
38+
year_op = YearOp()
3639

37-
iso_day_op = base_ops.create_unary_op(
40+
IsoDayOp = base_ops.create_unary_op(
3841
name="iso_day", type_signature=op_typing.DATELIKE_ACCESSOR
3942
)
43+
iso_day_op = IsoDayOp()
4044

41-
iso_week_op = base_ops.create_unary_op(
45+
IsoWeekOp = base_ops.create_unary_op(
4246
name="iso_weeek",
4347
type_signature=op_typing.DATELIKE_ACCESSOR,
4448
)
49+
iso_week_op = IsoWeekOp()
4550

46-
iso_year_op = base_ops.create_unary_op(
51+
IsoYearOp = base_ops.create_unary_op(
4752
name="iso_year",
4853
type_signature=op_typing.DATELIKE_ACCESSOR,
4954
)
55+
iso_year_op = IsoYearOp()
5056

51-
dayofweek_op = base_ops.create_unary_op(
57+
DayOfWeekOp = base_ops.create_unary_op(
5258
name="dayofweek",
5359
type_signature=op_typing.DATELIKE_ACCESSOR,
5460
)
61+
dayofweek_op = DayOfWeekOp()
5562

56-
dayofyear_op = base_ops.create_unary_op(
63+
DayOfYearOp = base_ops.create_unary_op(
5764
name="dayofyear",
5865
type_signature=op_typing.DATELIKE_ACCESSOR,
5966
)
67+
dayofyear_op = DayOfYearOp()
6068

61-
quarter_op = base_ops.create_unary_op(
69+
QuarterOp = base_ops.create_unary_op(
6270
name="quarter",
6371
type_signature=op_typing.DATELIKE_ACCESSOR,
6472
)
73+
quarter_op = QuarterOp()
6574

6675

6776
@dataclasses.dataclass(frozen=True)

0 commit comments

Comments
 (0)