Skip to content

Commit 9d96811

Browse files
committed
address comments
1 parent 3664162 commit 9d96811

File tree

2 files changed

+33
-38
lines changed

2 files changed

+33
-38
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3737
return sge.Concat(expressions=[left.expr, right.expr])
3838

3939
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
40-
left_expr, right_expr = _coerce_bools(left, right)
40+
left_expr = left.coerce_bool_to_int()
41+
right_expr = right.coerce_bool_to_int()
4142
return sge.Add(this=left_expr, expression=right_expr)
4243

4344
if (
4445
dtypes.is_time_or_date_like(left.dtype)
4546
and right.dtype == dtypes.TIMEDELTA_DTYPE
4647
):
47-
left_expr = left.expr
48-
if left.dtype == dtypes.DATE_DTYPE:
49-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
48+
left_expr = left.coerce_date_to_datetime()
5049
return sge.TimestampAdd(
5150
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
5251
)
5352
if (
5453
dtypes.is_time_or_date_like(right.dtype)
5554
and left.dtype == dtypes.TIMEDELTA_DTYPE
5655
):
57-
right_expr = right.expr
58-
if right.dtype == dtypes.DATE_DTYPE:
59-
right_expr = sge.Cast(this=right_expr, to="DATETIME")
56+
right_expr = right.coerce_date_to_datetime()
6057
return sge.TimestampAdd(
6158
this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND")
6259
)
@@ -70,19 +67,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7067

7168
@BINARY_OP_REGISTRATION.register(ops.eq_op)
7269
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
73-
left_expr, right_expr = _coerce_bools(left, right)
70+
left_expr = left.coerce_bool_to_int()
71+
right_expr = right.coerce_bool_to_int()
7472
return sge.EQ(this=left_expr, expression=right_expr)
7573

7674

7775
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
7876
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7977
left_expr = left.expr
80-
if left.dtype == dtypes.BOOL_DTYPE and right.dtype != dtypes.BOOL_DTYPE:
81-
left_expr = sge.Cast(this=left_expr, to="INT64")
78+
if right.dtype != dtypes.BOOL_DTYPE:
79+
left_expr = left.coerce_bool_to_int()
8280

8381
right_expr = right.expr
84-
if right.dtype == dtypes.BOOL_DTYPE and left.dtype != dtypes.BOOL_DTYPE:
85-
right_expr = sge.Cast(this=right_expr, to="INT64")
82+
if left.dtype != dtypes.BOOL_DTYPE:
83+
right_expr = right.coerce_bool_to_int()
8684

8785
sentinel = sge.convert("$NULL_SENTINEL$")
8886
left_coalesce = sge.Coalesce(
@@ -96,7 +94,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
9694

9795
@BINARY_OP_REGISTRATION.register(ops.div_op)
9896
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
99-
left_expr, right_expr = _coerce_bools(left, right)
97+
left_expr = left.coerce_bool_to_int()
98+
right_expr = right.coerce_bool_to_int()
10099

101100
result = sge.func("IEEE_DIVIDE", left_expr, right_expr)
102101
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
@@ -117,7 +116,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
117116

118117
@BINARY_OP_REGISTRATION.register(ops.mul_op)
119118
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
120-
left_expr, right_expr = _coerce_bools(left, right)
119+
left_expr = left.coerce_bool_to_int()
120+
right_expr = right.coerce_bool_to_int()
121121

122122
result = sge.Mul(this=left_expr, expression=right_expr)
123123

@@ -131,35 +131,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
131131

132132
@BINARY_OP_REGISTRATION.register(ops.ne_op)
133133
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
134-
left_expr, right_expr = _coerce_bools(left, right)
134+
left_expr = left.coerce_bool_to_int()
135+
right_expr = right.coerce_bool_to_int()
135136
return sge.NEQ(this=left_expr, expression=right_expr)
136137

137138

138139
@BINARY_OP_REGISTRATION.register(ops.sub_op)
139140
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
140141
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
141-
left_expr, right_expr = _coerce_bools(left, right)
142+
left_expr = left.coerce_bool_to_int()
143+
right_expr = right.coerce_bool_to_int()
142144
return sge.Sub(this=left_expr, expression=right_expr)
143145

144146
if (
145147
dtypes.is_time_or_date_like(left.dtype)
146148
and right.dtype == dtypes.TIMEDELTA_DTYPE
147149
):
148-
left_expr = left.expr
149-
if left.dtype == dtypes.DATE_DTYPE:
150-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
150+
left_expr = left.coerce_date_to_datetime()
151151
return sge.TimestampSub(
152152
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
153153
)
154154
if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like(
155155
right.dtype
156156
):
157-
left_expr = left.expr
158-
if left.dtype == dtypes.DATE_DTYPE:
159-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
160-
right_expr = right.expr
161-
if right.dtype == dtypes.DATE_DTYPE:
162-
right_expr = sge.Cast(this=right_expr, to="DATETIME")
157+
left_expr = left.coerce_date_to_datetime()
158+
right_expr = right.coerce_date_to_datetime()
163159
return sge.TimestampDiff(
164160
this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND")
165161
)
@@ -175,16 +171,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
175171
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
176172
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
177173
return sge.func("OBJ.MAKE_REF", left.expr, right.expr)
178-
179-
180-
def _coerce_bools(
181-
left: TypedExpr, right: TypedExpr
182-
) -> tuple[sge.Expression, sge.Expression]:
183-
"""Coerce boolean expressions to INT64 for binary operations."""
184-
left_expr = left.expr
185-
if left.dtype == dtypes.BOOL_DTYPE:
186-
left_expr = sge.Cast(this=left_expr, to="INT64")
187-
right_expr = right.expr
188-
if right.dtype == dtypes.BOOL_DTYPE:
189-
right_expr = sge.Cast(this=right_expr, to="INT64")
190-
return left_expr, right_expr

bigframes/core/compile/sqlglot/expressions/typed_expr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,15 @@ class TypedExpr:
2525

2626
expr: sge.Expression
2727
dtype: dtypes.ExpressionType
28+
29+
def coerce_bool_to_int(self) -> sge.Expression:
30+
"""Coerce boolean expression to integer."""
31+
if self.dtype == dtypes.BOOL_DTYPE:
32+
return sge.Cast(this=self.expr, to="INT64")
33+
return self.expr
34+
35+
def coerce_date_to_datetime(self) -> sge.Expression:
36+
"""Coerce date expression to datetime."""
37+
if self.dtype == dtypes.DATE_DTYPE:
38+
return sge.Cast(this=self.expr, to="DATETIME")
39+
return self.expr

0 commit comments

Comments
 (0)