From 84582b82a905831dc59cf9f54fc3e0966fe806db Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 15 Aug 2025 23:21:35 +0000 Subject: [PATCH] chore: implement comparison_ops for sqlglot compiler --- .../compile/sqlglot/expressions/binary_compiler.py | 11 +++++++++++ tests/system/small/engines/test_comparison_ops.py | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index b5d665e2e5..fc0e59fc7a 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -73,6 +73,17 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: ) +@BINARY_OP_REGISTRATION.register(ops.eq_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = left.expr + if left.dtype == dtypes.BOOL_DTYPE: + left_expr = sge.Cast(this=left_expr, to="INT64") + right_expr = right.expr + if right.dtype == dtypes.BOOL_DTYPE: + right_expr = sge.Cast(this=right_expr, to="INT64") + return sge.EQ(this=left_expr, expression=right_expr) + + @BINARY_OP_REGISTRATION.register(ops.div_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = left.expr diff --git a/tests/system/small/engines/test_comparison_ops.py b/tests/system/small/engines/test_comparison_ops.py index fefff93f58..6b97c8bfa8 100644 --- a/tests/system/small/engines/test_comparison_ops.py +++ b/tests/system/small/engines/test_comparison_ops.py @@ -48,7 +48,8 @@ def apply_op_pairwise( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +# @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize("engine", ["bq-sqlglot"], indirect=True) @pytest.mark.parametrize( "op", [