17
17
from bigframes import dtypes
18
18
from bigframes .core import bigframe_node , expression
19
19
from bigframes .core .rewrite import op_lowering
20
- from bigframes .operations import comparison_ops , numeric_ops
20
+ from bigframes .operations import comparison_ops , datetime_ops , json_ops , numeric_ops
21
21
import bigframes .operations as ops
22
22
23
23
# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
@@ -278,6 +278,16 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
278
278
return wo_bools
279
279
280
280
281
+ class LowerAsTypeRule (op_lowering .OpLoweringRule ):
282
+ @property
283
+ def op (self ) -> type [ops .ScalarOp ]:
284
+ return ops .AsTypeOp
285
+
286
+ def lower (self , expr : expression .OpExpression ) -> expression .Expression :
287
+ assert isinstance (expr .op , ops .AsTypeOp )
288
+ return _lower_cast (expr .op , expr .inputs [0 ])
289
+
290
+
281
291
def _coerce_comparables (
282
292
expr1 : expression .Expression ,
283
293
expr2 : expression .Expression ,
@@ -299,12 +309,57 @@ def _coerce_comparables(
299
309
return expr1 , expr2
300
310
301
311
302
- # TODO: Need to handle bool->string cast to get capitalization correct
303
312
def _lower_cast (cast_op : ops .AsTypeOp , arg : expression .Expression ):
313
+ if arg .output_type == cast_op .to_type :
314
+ return arg
315
+
316
+ if arg .output_type == dtypes .JSON_DTYPE :
317
+ return json_ops .JSONDecode (cast_op .to_type ).as_expr (arg )
318
+ if (
319
+ arg .output_type == dtypes .STRING_DTYPE
320
+ and cast_op .to_type == dtypes .DATETIME_DTYPE
321
+ ):
322
+ return datetime_ops .ParseDatetimeOp ().as_expr (arg )
323
+ if (
324
+ arg .output_type == dtypes .STRING_DTYPE
325
+ and cast_op .to_type == dtypes .TIMESTAMP_DTYPE
326
+ ):
327
+ return datetime_ops .ParseTimestampOp ().as_expr (arg )
328
+ # date -> string casting
329
+ if (
330
+ arg .output_type == dtypes .DATETIME_DTYPE
331
+ and cast_op .to_type == dtypes .STRING_DTYPE
332
+ ):
333
+ return datetime_ops .StrftimeOp ("%Y-%m-%d %H:%M:%S" ).as_expr (arg )
334
+ if arg .output_type == dtypes .TIME_DTYPE and cast_op .to_type == dtypes .STRING_DTYPE :
335
+ return datetime_ops .StrftimeOp ("%H:%M:%S.%6f" ).as_expr (arg )
336
+ if (
337
+ arg .output_type == dtypes .TIMESTAMP_DTYPE
338
+ and cast_op .to_type == dtypes .STRING_DTYPE
339
+ ):
340
+ return datetime_ops .StrftimeOp ("%Y-%m-%d %H:%M:%S%.6f%:::z" ).as_expr (arg )
341
+ if arg .output_type == dtypes .BOOL_DTYPE and cast_op .to_type == dtypes .STRING_DTYPE :
342
+ # bool -> decimal needs two-step cast
343
+ new_arg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (arg )
344
+ is_true_cond = ops .eq_op .as_expr (arg , expression .const (True ))
345
+ is_false_cond = ops .eq_op .as_expr (arg , expression .const (False ))
346
+ return ops .CaseWhenOp ().as_expr (
347
+ is_true_cond ,
348
+ expression .const ("True" ),
349
+ is_false_cond ,
350
+ expression .const ("False" ),
351
+ )
304
352
if arg .output_type == dtypes .BOOL_DTYPE and dtypes .is_numeric (cast_op .to_type ):
305
353
# bool -> decimal needs two-step cast
306
354
new_arg = ops .AsTypeOp (to_type = dtypes .INT_DTYPE ).as_expr (arg )
307
355
return cast_op .as_expr (new_arg )
356
+ if arg .output_type == dtypes .TIME_DTYPE and dtypes .is_numeric (cast_op .to_type ):
357
+ # polars cast gives nanoseconds, so convert to microseconds
358
+ return numeric_ops .floordiv_op .as_expr (
359
+ cast_op .as_expr (arg ), expression .const (1000 )
360
+ )
361
+ if dtypes .is_numeric (arg .output_type ) and cast_op .to_type == dtypes .TIME_DTYPE :
362
+ return cast_op .as_expr (ops .mul_op .as_expr (expression .const (1000 ), arg ))
308
363
return cast_op .as_expr (arg )
309
364
310
365
@@ -329,6 +384,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
329
384
LowerDivRule (),
330
385
LowerFloorDivRule (),
331
386
LowerModRule (),
387
+ LowerAsTypeRule (),
332
388
)
333
389
334
390
0 commit comments