Skip to content

Commit d9bc4a5

Browse files
feat: Can cast locally in hybrid engine (#1944)
1 parent c4c7fa5 commit d9bc4a5

File tree

10 files changed

+676
-7
lines changed

10 files changed

+676
-7
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
import bigframes.operations.aggregations as agg_ops
3434
import bigframes.operations.bool_ops as bool_ops
3535
import bigframes.operations.comparison_ops as comp_ops
36+
import bigframes.operations.datetime_ops as dt_ops
3637
import bigframes.operations.generic_ops as gen_ops
38+
import bigframes.operations.json_ops as json_ops
3739
import bigframes.operations.numeric_ops as num_ops
3840
import bigframes.operations.string_ops as string_ops
3941

@@ -280,6 +282,30 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
280282
assert isinstance(op, string_ops.StrConcatOp)
281283
return pl.concat_str(l_input, r_input)
282284

285+
@compile_op.register(dt_ops.StrftimeOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``StrftimeOp`` by formatting ``input`` with the op's ``date_format``."""
    assert isinstance(op, dt_ops.StrftimeOp)
    fmt = op.date_format
    return input.dt.strftime(fmt)
289+
290+
@compile_op.register(dt_ops.ParseDatetimeOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``ParseDatetimeOp`` as a string -> datetime conversion.

    The conversion requests microsecond precision and no time zone.
    """
    assert isinstance(op, dt_ops.ParseDatetimeOp)
    parsed = input.str.to_datetime(
        time_unit="us",
        time_zone=None,
        ambiguous="earliest",
    )
    return parsed
296+
297+
@compile_op.register(dt_ops.ParseTimestampOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``ParseTimestampOp`` as a string -> datetime conversion.

    The conversion requests microsecond precision and the UTC time zone.
    """
    assert isinstance(op, dt_ops.ParseTimestampOp)
    parsed = input.str.to_datetime(
        time_unit="us",
        time_zone="UTC",
        ambiguous="earliest",
    )
    return parsed
303+
304+
@compile_op.register(json_ops.JSONDecode)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
    """Compile ``JSONDecode`` by decoding ``input`` into the mapped target dtype."""
    assert isinstance(op, json_ops.JSONDecode)
    # Look up the dtype registered in _DTYPE_MAPPING for the op's target type.
    target_dtype = _DTYPE_MAPPING[op.to_type]
    return input.str.json_decode(target_dtype)
308+
283309
@dataclasses.dataclass(frozen=True)
284310
class PolarsAggregateCompiler:
285311
scalar_compiler = PolarsExpressionCompiler()

bigframes/core/compile/polars/lowering.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from bigframes import dtypes
1818
from bigframes.core import bigframe_node, expression
1919
from bigframes.core.rewrite import op_lowering
20-
from bigframes.operations import comparison_ops, numeric_ops
20+
from bigframes.operations import comparison_ops, datetime_ops, json_ops, numeric_ops
2121
import bigframes.operations as ops
2222

2323
# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
@@ -278,6 +278,16 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
278278
return wo_bools
279279

280280

281+
class LowerAsTypeRule(op_lowering.OpLoweringRule):
    """Lowering rule that routes ``AsTypeOp`` expressions through ``_lower_cast``."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to."""
        return ops.AsTypeOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Lower a cast expression by delegating to ``_lower_cast``."""
        cast_op = expr.op
        assert isinstance(cast_op, ops.AsTypeOp)
        operand = expr.inputs[0]
        return _lower_cast(cast_op, operand)
289+
290+
281291
def _coerce_comparables(
282292
expr1: expression.Expression,
283293
expr2: expression.Expression,
@@ -299,12 +309,57 @@ def _coerce_comparables(
299309
return expr1, expr2
300310

301311

302-
# TODO: Need to handle bool->string cast to get capitalization correct
303312
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
    """Rewrite a cast into an expression the polars compiler can evaluate.

    Args:
        cast_op: The cast being lowered; ``cast_op.to_type`` is the target type.
        arg: The expression being cast.

    Returns:
        ``arg`` unchanged when it already has the target type, a dedicated
        decode/parse/format expression for the special-cased type pairs
        handled below, or otherwise ``cast_op`` applied to ``arg``.
    """
    from_type = arg.output_type
    to_type = cast_op.to_type

    # No-op cast.
    if from_type == to_type:
        return arg

    # JSON input is decoded straight into the target type.
    if from_type == dtypes.JSON_DTYPE:
        return json_ops.JSONDecode(to_type).as_expr(arg)

    # string -> datetime / timestamp use dedicated parsing ops.
    if from_type == dtypes.STRING_DTYPE and to_type == dtypes.DATETIME_DTYPE:
        return datetime_ops.ParseDatetimeOp().as_expr(arg)
    if from_type == dtypes.STRING_DTYPE and to_type == dtypes.TIMESTAMP_DTYPE:
        return datetime_ops.ParseTimestampOp().as_expr(arg)

    # datetime / time / timestamp -> string use explicit format strings.
    if from_type == dtypes.DATETIME_DTYPE and to_type == dtypes.STRING_DTYPE:
        return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S").as_expr(arg)
    if from_type == dtypes.TIME_DTYPE and to_type == dtypes.STRING_DTYPE:
        return datetime_ops.StrftimeOp("%H:%M:%S.%6f").as_expr(arg)
    if from_type == dtypes.TIMESTAMP_DTYPE and to_type == dtypes.STRING_DTYPE:
        return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S%.6f%:::z").as_expr(arg)

    if from_type == dtypes.BOOL_DTYPE and to_type == dtypes.STRING_DTYPE:
        # bool -> string: emit the capitalized literals "True"/"False"
        # explicitly via a case-when instead of relying on the engine's cast.
        is_true_cond = ops.eq_op.as_expr(arg, expression.const(True))
        is_false_cond = ops.eq_op.as_expr(arg, expression.const(False))
        return ops.CaseWhenOp().as_expr(
            is_true_cond,
            expression.const("True"),
            is_false_cond,
            expression.const("False"),
        )

    if from_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(to_type):
        # bool -> numeric needs a two-step cast: bool -> int -> target.
        new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
        return cast_op.as_expr(new_arg)

    if from_type == dtypes.TIME_DTYPE and dtypes.is_numeric(to_type):
        # polars cast gives nanoseconds, so convert to microseconds
        return numeric_ops.floordiv_op.as_expr(
            cast_op.as_expr(arg), expression.const(1000)
        )

    if dtypes.is_numeric(from_type) and to_type == dtypes.TIME_DTYPE:
        # Scale by 1000 before casting; presumably the inverse of the
        # nanosecond/microsecond adjustment in the branch above.
        return cast_op.as_expr(ops.mul_op.as_expr(expression.const(1000), arg))

    return cast_op.as_expr(arg)
309364

310365

@@ -329,6 +384,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
329384
LowerDivRule(),
330385
LowerFloorDivRule(),
331386
LowerModRule(),
387+
LowerAsTypeRule(),
332388
)
333389

334390

bigframes/operations/datetime_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,28 @@
3939
time_op = TimeOp()
4040

4141

42+
@dataclasses.dataclass(frozen=True)
class ParseDatetimeOp(base_ops.UnaryOp):
    """Op named ``parse_datetime``: string input, timezone-naive microsecond timestamp output."""

    # TODO: Support strict format
    name: typing.ClassVar[str] = "parse_datetime"

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the result dtype; raise ``TypeError`` unless the first input is a string."""
        source_type = input_types[0]
        if source_type != dtypes.STRING_DTYPE:
            raise TypeError("expected string input")
        arrow_type = pa.timestamp("us", tz=None)
        return pd.ArrowDtype(arrow_type)
51+
52+
53+
@dataclasses.dataclass(frozen=True)
class ParseTimestampOp(base_ops.UnaryOp):
    """Op named ``parse_timestamp``: string input, UTC microsecond timestamp output."""

    # TODO: Support strict format
    name: typing.ClassVar[str] = "parse_timestamp"

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the result dtype; raise ``TypeError`` unless the first input is a string."""
        source_type = input_types[0]
        if source_type != dtypes.STRING_DTYPE:
            raise TypeError("expected string input")
        arrow_type = pa.timestamp("us", tz="UTC")
        return pd.ArrowDtype(arrow_type)
62+
63+
4264
@dataclasses.dataclass(frozen=True)
4365
class ToDatetimeOp(base_ops.UnaryOp):
4466
name: typing.ClassVar[str] = "to_datetime"

0 commit comments

Comments
 (0)