From 31b2d36ce047a2a7e352935f6ab2d1f3291b56b5 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 14 Aug 2025 23:31:22 +0000 Subject: [PATCH] chore: implement floordiv_op compiler --- .../sqlglot/expressions/binary_compiler.py | 44 ++++- .../compile/sqlglot/expressions/constants.py | 24 +++ .../sqlglot/expressions/unary_compiler.py | 39 ++--- .../system/small/engines/test_numeric_ops.py | 4 +- .../test_div_numeric/out.sql | 130 +++++++++++---- .../test_floordiv_numeric/out.sql | 154 ++++++++++++++++++ .../test_floordiv_timedelta/out.sql | 18 ++ .../expressions/test_binary_compiler.py | 33 +++- 8 files changed, 386 insertions(+), 60 deletions(-) create mode 100644 bigframes/core/compile/sqlglot/expressions/constants.py create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_numeric/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_timedelta/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index fa640ee0b2..d514c79f83 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -14,11 +14,12 @@ from __future__ import annotations -import bigframes_vendored.constants as constants +import bigframes_vendored.constants as bf_constants import sqlglot.expressions as sge from bigframes import dtypes from bigframes import operations as ops +import bigframes.core.compile.sqlglot.expressions.constants as constants from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr @@ -69,7 +70,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Add(this=left.expr, expression=right.expr) raise TypeError( - f"Cannot add type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}" + f"Cannot add type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}" ) @@ -89,6 +90,43 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return result +@BINARY_OP_REGISTRATION.register(ops.floordiv_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = left.expr + if left.dtype == dtypes.BOOL_DTYPE: + left_expr = sge.Cast(this=left_expr, to="INT64") + right_expr = right.expr + if right.dtype == dtypes.BOOL_DTYPE: + right_expr = sge.Cast(this=right_expr, to="INT64") + + result: sge.Expression = sge.Cast( + this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64" + ) + + # DIV(N, 0) will error in bigquery, but needs to return `0` for int, and + # `inf`` for float in BQ so we short-circuit in this case. + # Multiplying left by zero propogates nulls. + zero_result = ( + constants._INF + if (left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE) + else constants._ZERO + ) + result = sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=right_expr, expression=constants._ZERO), + true=zero_result * left_expr, + ) + ], + default=result, + ) + + if dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE: + result = sge.Cast(this=sge.Floor(this=result), to="INT64") + + return result + + @BINARY_OP_REGISTRATION.register(ops.ge_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.GTE(this=left.expr, expression=right.expr) @@ -156,5 +194,5 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Sub(this=left.expr, expression=right.expr) raise TypeError( - f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}" + f"Cannot subtract type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}" ) diff --git a/bigframes/core/compile/sqlglot/expressions/constants.py b/bigframes/core/compile/sqlglot/expressions/constants.py new file mode 100644 index 0000000000..f4ae9baca2 --- /dev/null +++ b/bigframes/core/compile/sqlglot/expressions/constants.py @@ -0,0 +1,24 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sqlglot.expressions as sge + +_ZERO = sge.Cast(this=sge.convert(0), to="INT64") +_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64") +_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64") + +# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result +# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10) +# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow. +_FLOAT64_EXP_BOUND = sge.convert(709.78) diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 125c60bbf4..5c18441f8c 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -22,17 +22,10 @@ import sqlglot.expressions as sge from bigframes import operations as ops +import bigframes.core.compile.sqlglot.expressions.constants as constants from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr -_NAN = sge.Cast(this=sge.convert("NaN"), to="FLOAT64") -_INF = sge.Cast(this=sge.convert("Infinity"), to="FLOAT64") - -# Approx Highest number you can pass in to EXP function and get a valid FLOAT64 result -# FLOAT64 has 11 exponent bits, so max values is about 2**(2**10) -# ln(2**(2**10)) == (2**10)*ln(2) ~= 709.78, so EXP(x) for x>709.78 will overflow. -_FLOAT64_EXP_BOUND = sge.convert(709.78) - UNARY_OP_REGISTRATION = OpRegistration() @@ -51,7 +44,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=expr.expr < sge.convert(1), - true=_NAN, + true=constants._NAN, ) ], default=sge.func("ACOSH", expr.expr), @@ -64,7 +57,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=sge.func("ABS", expr.expr) > sge.convert(1), - true=_NAN, + true=constants._NAN, ) ], default=sge.func("ACOS", expr.expr), @@ -77,7 +70,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=sge.func("ABS", expr.expr) > sge.convert(1), - true=_NAN, + true=constants._NAN, ) ], default=sge.func("ASIN", expr.expr), @@ -100,7 +93,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=sge.func("ABS", expr.expr) > sge.convert(1), - true=_NAN, + true=constants._NAN, ) ], default=sge.func("ATANH", expr.expr), @@ -176,7 +169,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=sge.func("ABS", expr.expr) > sge.convert(709.78), - true=_INF, + true=constants._INF, ) ], default=sge.func("COSH", expr.expr), @@ -221,8 +214,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( - this=expr.expr > _FLOAT64_EXP_BOUND, - true=_INF, + this=expr.expr > constants._FLOAT64_EXP_BOUND, + true=constants._INF, ) ], default=sge.func("EXP", expr.expr), @@ -234,8 +227,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( - this=expr.expr > _FLOAT64_EXP_BOUND, - true=_INF, + this=expr.expr > constants._FLOAT64_EXP_BOUND, + true=constants._INF, ) ], default=sge.func("EXP", expr.expr), @@ -382,7 +375,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=expr.expr < sge.convert(0), - true=_NAN, + true=constants._NAN, ) ], default=sge.Ln(this=expr.expr), @@ -395,7 +388,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=expr.expr < sge.convert(0), - true=_NAN, + true=constants._NAN, ) ], default=sge.Log(this=expr.expr, expression=sge.convert(10)), @@ -408,7 +401,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=expr.expr < sge.convert(-1), - true=_NAN, + true=constants._NAN, ) ], default=sge.Ln(this=sge.convert(1) + expr.expr), @@ -476,7 +469,7 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ifs=[ sge.If( this=expr.expr < sge.convert(0), - true=_NAN, + true=constants._NAN, ) ], default=sge.Sqrt(this=expr.expr), @@ -523,8 +516,8 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Case( ifs=[ sge.If( - this=sge.func("ABS", expr.expr) > _FLOAT64_EXP_BOUND, - true=sge.func("SIGN", expr.expr) * _INF, + this=sge.func("ABS", expr.expr) > constants._FLOAT64_EXP_BOUND, + true=sge.func("SIGN", expr.expr) * constants._INF, ) ], default=sge.func("SINH", expr.expr), diff --git a/tests/system/small/engines/test_numeric_ops.py b/tests/system/small/engines/test_numeric_ops.py index b46a2f1c56..7928922e41 100644 --- a/tests/system/small/engines/test_numeric_ops.py +++ b/tests/system/small/engines/test_numeric_ops.py @@ -117,7 +117,7 @@ def test_engines_project_div_durations( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_project_floordiv( scalars_array_value: array_value.ArrayValue, engine, @@ -130,7 +130,7 @@ def test_engines_project_floordiv( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_project_floordiv_durations( scalars_array_value: array_value.ArrayValue, engine ): diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_numeric/out.sql index c1f4e0cb69..03d48276a0 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_numeric/out.sql @@ -2,53 +2,121 @@ WITH `bfcte_0` AS ( SELECT `bool_col` AS `bfcol_0`, `int64_col` AS `bfcol_1`, - `rowindex` AS `bfcol_2` + `float64_col` AS `bfcol_2`, + `rowindex` AS `bfcol_3` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - `bfcol_2` AS `bfcol_6`, - `bfcol_1` AS `bfcol_7`, - `bfcol_0` AS `bfcol_8`, - IEEE_DIVIDE(`bfcol_1`, `bfcol_1`) AS `bfcol_9` + `bfcol_3` AS `bfcol_8`, + `bfcol_1` AS `bfcol_9`, + `bfcol_0` AS `bfcol_10`, + `bfcol_2` AS `bfcol_11`, + IEEE_DIVIDE(`bfcol_1`, `bfcol_1`) AS `bfcol_12` FROM `bfcte_0` ), `bfcte_2` AS ( SELECT *, - `bfcol_6` AS `bfcol_14`, - `bfcol_7` AS `bfcol_15`, - `bfcol_8` AS `bfcol_16`, - `bfcol_9` AS `bfcol_17`, - IEEE_DIVIDE(`bfcol_7`, 1) AS `bfcol_18` + `bfcol_8` AS `bfcol_18`, + `bfcol_9` AS `bfcol_19`, + `bfcol_10` AS `bfcol_20`, + `bfcol_11` AS `bfcol_21`, + `bfcol_12` AS `bfcol_22`, + IEEE_DIVIDE(`bfcol_9`, 1) AS `bfcol_23` FROM `bfcte_1` ), `bfcte_3` AS ( SELECT *, - `bfcol_14` AS `bfcol_24`, - `bfcol_15` AS `bfcol_25`, - `bfcol_16` AS `bfcol_26`, - `bfcol_17` AS `bfcol_27`, - `bfcol_18` AS `bfcol_28`, - IEEE_DIVIDE(`bfcol_15`, CAST(`bfcol_16` AS INT64)) AS `bfcol_29` + `bfcol_18` AS `bfcol_30`, + `bfcol_19` AS `bfcol_31`, + `bfcol_20` AS `bfcol_32`, + `bfcol_21` AS `bfcol_33`, + `bfcol_22` AS `bfcol_34`, + `bfcol_23` AS `bfcol_35`, + IEEE_DIVIDE(`bfcol_19`, 0.0) AS `bfcol_36` FROM `bfcte_2` ), `bfcte_4` AS ( SELECT *, - `bfcol_24` AS `bfcol_36`, - `bfcol_25` AS `bfcol_37`, - `bfcol_26` AS `bfcol_38`, - `bfcol_27` AS `bfcol_39`, - `bfcol_28` AS `bfcol_40`, - `bfcol_29` AS `bfcol_41`, - IEEE_DIVIDE(CAST(`bfcol_26` AS INT64), `bfcol_25`) AS `bfcol_42` + `bfcol_30` AS `bfcol_44`, + `bfcol_31` AS `bfcol_45`, + `bfcol_32` AS `bfcol_46`, + `bfcol_33` AS `bfcol_47`, + `bfcol_34` AS `bfcol_48`, + `bfcol_35` AS `bfcol_49`, + `bfcol_36` AS `bfcol_50`, + IEEE_DIVIDE(`bfcol_31`, `bfcol_33`) AS `bfcol_51` FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + `bfcol_44` AS `bfcol_60`, + `bfcol_45` AS `bfcol_61`, + `bfcol_46` AS `bfcol_62`, + `bfcol_47` AS `bfcol_63`, + `bfcol_48` AS `bfcol_64`, + `bfcol_49` AS `bfcol_65`, + `bfcol_50` AS `bfcol_66`, + `bfcol_51` AS `bfcol_67`, + IEEE_DIVIDE(`bfcol_47`, `bfcol_45`) AS `bfcol_68` + FROM `bfcte_4` +), `bfcte_6` AS ( + SELECT + *, + `bfcol_60` AS `bfcol_78`, + `bfcol_61` AS `bfcol_79`, + `bfcol_62` AS `bfcol_80`, + `bfcol_63` AS `bfcol_81`, + `bfcol_64` AS `bfcol_82`, + `bfcol_65` AS `bfcol_83`, + `bfcol_66` AS `bfcol_84`, + `bfcol_67` AS `bfcol_85`, + `bfcol_68` AS `bfcol_86`, + IEEE_DIVIDE(`bfcol_63`, 0.0) AS `bfcol_87` + FROM `bfcte_5` +), `bfcte_7` AS ( + SELECT + *, + `bfcol_78` AS `bfcol_98`, + `bfcol_79` AS `bfcol_99`, + `bfcol_80` AS `bfcol_100`, + `bfcol_81` AS `bfcol_101`, + `bfcol_82` AS `bfcol_102`, + `bfcol_83` AS `bfcol_103`, + `bfcol_84` AS `bfcol_104`, + `bfcol_85` AS `bfcol_105`, + `bfcol_86` AS `bfcol_106`, + `bfcol_87` AS `bfcol_107`, + IEEE_DIVIDE(`bfcol_79`, CAST(`bfcol_80` AS INT64)) AS `bfcol_108` + FROM `bfcte_6` +), `bfcte_8` AS ( + SELECT + *, + `bfcol_98` AS `bfcol_120`, + `bfcol_99` AS `bfcol_121`, + `bfcol_100` AS `bfcol_122`, + `bfcol_101` AS `bfcol_123`, + `bfcol_102` AS `bfcol_124`, + `bfcol_103` AS `bfcol_125`, + `bfcol_104` AS `bfcol_126`, + `bfcol_105` AS `bfcol_127`, + `bfcol_106` AS `bfcol_128`, + `bfcol_107` AS `bfcol_129`, + `bfcol_108` AS `bfcol_130`, + IEEE_DIVIDE(CAST(`bfcol_100` AS INT64), `bfcol_99`) AS `bfcol_131` + FROM `bfcte_7` ) SELECT - `bfcol_36` AS `rowindex`, - `bfcol_37` AS `int64_col`, - `bfcol_38` AS `bool_col`, - `bfcol_39` AS `int_div_int`, - `bfcol_40` AS `int_div_1`, - `bfcol_41` AS `int_div_bool`, - `bfcol_42` AS `bool_div_int` -FROM `bfcte_4` \ No newline at end of file + `bfcol_120` AS `rowindex`, + `bfcol_121` AS `int64_col`, + `bfcol_122` AS `bool_col`, + `bfcol_123` AS `float64_col`, + `bfcol_124` AS `int_div_int`, + `bfcol_125` AS `int_div_1`, + `bfcol_126` AS `int_div_0`, + `bfcol_127` AS `int_div_float`, + `bfcol_128` AS `float_div_int`, + `bfcol_129` AS `float_div_0`, + `bfcol_130` AS `int_div_bool`, + `bfcol_131` AS `bool_div_int` +FROM `bfcte_8` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_numeric/out.sql new file mode 100644 index 0000000000..c38bc18523 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_numeric/out.sql @@ -0,0 +1,154 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `float64_col` AS `bfcol_2`, + `rowindex` AS `bfcol_3` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_3` AS `bfcol_8`, + `bfcol_1` AS `bfcol_9`, + `bfcol_0` AS `bfcol_10`, + `bfcol_2` AS `bfcol_11`, + CASE + WHEN `bfcol_1` = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_1` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_1`, `bfcol_1`)) AS INT64) + END AS `bfcol_12` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_8` AS `bfcol_18`, + `bfcol_9` AS `bfcol_19`, + `bfcol_10` AS `bfcol_20`, + `bfcol_11` AS `bfcol_21`, + `bfcol_12` AS `bfcol_22`, + CASE + WHEN 1 = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_9` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_9`, 1)) AS INT64) + END AS `bfcol_23` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_18` AS `bfcol_30`, + `bfcol_19` AS `bfcol_31`, + `bfcol_20` AS `bfcol_32`, + `bfcol_21` AS `bfcol_33`, + `bfcol_22` AS `bfcol_34`, + `bfcol_23` AS `bfcol_35`, + CASE + WHEN 0.0 = CAST(0 AS INT64) + THEN CAST('Infinity' AS FLOAT64) * `bfcol_19` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_19`, 0.0)) AS INT64) + END AS `bfcol_36` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_30` AS `bfcol_44`, + `bfcol_31` AS `bfcol_45`, + `bfcol_32` AS `bfcol_46`, + `bfcol_33` AS `bfcol_47`, + `bfcol_34` AS `bfcol_48`, + `bfcol_35` AS `bfcol_49`, + `bfcol_36` AS `bfcol_50`, + CASE + WHEN `bfcol_33` = CAST(0 AS INT64) + THEN CAST('Infinity' AS FLOAT64) * `bfcol_31` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_31`, `bfcol_33`)) AS INT64) + END AS `bfcol_51` + FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + `bfcol_44` AS `bfcol_60`, + `bfcol_45` AS `bfcol_61`, + `bfcol_46` AS `bfcol_62`, + `bfcol_47` AS `bfcol_63`, + `bfcol_48` AS `bfcol_64`, + `bfcol_49` AS `bfcol_65`, + `bfcol_50` AS `bfcol_66`, + `bfcol_51` AS `bfcol_67`, + CASE + WHEN `bfcol_45` = CAST(0 AS INT64) + THEN CAST('Infinity' AS FLOAT64) * `bfcol_47` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_47`, `bfcol_45`)) AS INT64) + END AS `bfcol_68` + FROM `bfcte_4` +), `bfcte_6` AS ( + SELECT + *, + `bfcol_60` AS `bfcol_78`, + `bfcol_61` AS `bfcol_79`, + `bfcol_62` AS `bfcol_80`, + `bfcol_63` AS `bfcol_81`, + `bfcol_64` AS `bfcol_82`, + `bfcol_65` AS `bfcol_83`, + `bfcol_66` AS `bfcol_84`, + `bfcol_67` AS `bfcol_85`, + `bfcol_68` AS `bfcol_86`, + CASE + WHEN 0.0 = CAST(0 AS INT64) + THEN CAST('Infinity' AS FLOAT64) * `bfcol_63` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_63`, 0.0)) AS INT64) + END AS `bfcol_87` + FROM `bfcte_5` +), `bfcte_7` AS ( + SELECT + *, + `bfcol_78` AS `bfcol_98`, + `bfcol_79` AS `bfcol_99`, + `bfcol_80` AS `bfcol_100`, + `bfcol_81` AS `bfcol_101`, + `bfcol_82` AS `bfcol_102`, + `bfcol_83` AS `bfcol_103`, + `bfcol_84` AS `bfcol_104`, + `bfcol_85` AS `bfcol_105`, + `bfcol_86` AS `bfcol_106`, + `bfcol_87` AS `bfcol_107`, + CASE + WHEN CAST(`bfcol_80` AS INT64) = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_79` + ELSE CAST(FLOOR(IEEE_DIVIDE(`bfcol_79`, CAST(`bfcol_80` AS INT64))) AS INT64) + END AS `bfcol_108` + FROM `bfcte_6` +), `bfcte_8` AS ( + SELECT + *, + `bfcol_98` AS `bfcol_120`, + `bfcol_99` AS `bfcol_121`, + `bfcol_100` AS `bfcol_122`, + `bfcol_101` AS `bfcol_123`, + `bfcol_102` AS `bfcol_124`, + `bfcol_103` AS `bfcol_125`, + `bfcol_104` AS `bfcol_126`, + `bfcol_105` AS `bfcol_127`, + `bfcol_106` AS `bfcol_128`, + `bfcol_107` AS `bfcol_129`, + `bfcol_108` AS `bfcol_130`, + CASE + WHEN `bfcol_99` = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * CAST(`bfcol_100` AS INT64) + ELSE CAST(FLOOR(IEEE_DIVIDE(CAST(`bfcol_100` AS INT64), `bfcol_99`)) AS INT64) + END AS `bfcol_131` + FROM `bfcte_7` +) +SELECT + `bfcol_120` AS `rowindex`, + `bfcol_121` AS `int64_col`, + `bfcol_122` AS `bool_col`, + `bfcol_123` AS `float64_col`, + `bfcol_124` AS `int_div_int`, + `bfcol_125` AS `int_div_1`, + `bfcol_126` AS `int_div_0`, + `bfcol_127` AS `int_div_float`, + `bfcol_128` AS `float_div_int`, + `bfcol_129` AS `float_div_0`, + `bfcol_130` AS `int_div_bool`, + `bfcol_131` AS `bool_div_int` +FROM `bfcte_8` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_timedelta/out.sql new file mode 100644 index 0000000000..bc4f94d306 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_timedelta/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `date_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1`, + `timestamp_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + 43200000000 AS `bfcol_6` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `rowindex`, + `bfcol_2` AS `timestamp_col`, + `bfcol_0` AS `date_col`, + `bfcol_6` AS `timedelta_div_numeric` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py index 6521a92df0..a5b59ae24b 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py @@ -83,10 +83,15 @@ def test_add_unsupported_raises(scalar_types_df: bpd.DataFrame): def test_div_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] + bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] bf_df["int_div_int"] = bf_df["int64_col"] / bf_df["int64_col"] bf_df["int_div_1"] = bf_df["int64_col"] / 1 + bf_df["int_div_0"] = bf_df["int64_col"] / 0.0 + + bf_df["int_div_float"] = bf_df["int64_col"] / bf_df["float64_col"] + bf_df["float_div_int"] = bf_df["float64_col"] / bf_df["int64_col"] + bf_df["float_div_0"] = bf_df["float64_col"] / 0.0 bf_df["int_div_bool"] = bf_df["int64_col"] / bf_df["bool_col"] bf_df["bool_div_int"] = bf_df["bool_col"] / bf_df["int64_col"] @@ -102,6 +107,32 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(bf_df.sql, "out.sql") +def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] + + bf_df["int_div_int"] = bf_df["int64_col"] // bf_df["int64_col"] + bf_df["int_div_1"] = bf_df["int64_col"] // 1 + bf_df["int_div_0"] = bf_df["int64_col"] // 0.0 + + bf_df["int_div_float"] = bf_df["int64_col"] // bf_df["float64_col"] + bf_df["float_div_int"] = bf_df["float64_col"] // bf_df["int64_col"] + bf_df["float_div_0"] = bf_df["float64_col"] // 0.0 + + bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"] + bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "date_col"]] + timedelta = pd.Timedelta(1, unit="d") + + bf_df["timedelta_div_numeric"] = timedelta // 2 + + snapshot.assert_match(bf_df.sql, "out.sql") + + def test_json_set(json_types_df: bpd.DataFrame, snapshot): bf_df = json_types_df[["json_col"]] sql = _apply_binary_op(