Skip to content

Commit 31814df

Browse files
committed
address comments
1 parent e199c75 commit 31814df

File tree

2 files changed

+35
-44
lines changed

2 files changed

+35
-44
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3838
return sge.Concat(expressions=[left.expr, right.expr])
3939

4040
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
41-
left_expr, right_expr = _coerce_bools(left, right)
41+
left_expr = left.coerce_bool_to_int()
42+
right_expr = right.coerce_bool_to_int()
4243
return sge.Add(this=left_expr, expression=right_expr)
4344

4445
if (
4546
dtypes.is_time_or_date_like(left.dtype)
4647
and right.dtype == dtypes.TIMEDELTA_DTYPE
4748
):
48-
left_expr = left.expr
49-
if left.dtype == dtypes.DATE_DTYPE:
50-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
49+
left_expr = left.coerce_date_to_datetime()
5150
return sge.TimestampAdd(
5251
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
5352
)
5453
if (
5554
dtypes.is_time_or_date_like(right.dtype)
5655
and left.dtype == dtypes.TIMEDELTA_DTYPE
5756
):
58-
right_expr = right.expr
59-
if right.dtype == dtypes.DATE_DTYPE:
60-
right_expr = sge.Cast(this=right_expr, to="DATETIME")
57+
right_expr = right.coerce_date_to_datetime()
6158
return sge.TimestampAdd(
6259
this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND")
6360
)
@@ -71,19 +68,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7168

7269
@BINARY_OP_REGISTRATION.register(ops.eq_op)
7370
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
74-
left_expr, right_expr = _coerce_bools(left, right)
71+
left_expr = left.coerce_bool_to_int()
72+
right_expr = right.coerce_bool_to_int()
7573
return sge.EQ(this=left_expr, expression=right_expr)
7674

7775

7876
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
7977
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
8078
left_expr = left.expr
81-
if left.dtype == dtypes.BOOL_DTYPE and right.dtype != dtypes.BOOL_DTYPE:
82-
left_expr = sge.Cast(this=left_expr, to="INT64")
79+
if right.dtype != dtypes.BOOL_DTYPE:
80+
left_expr = left.coerce_bool_to_int()
8381

8482
right_expr = right.expr
85-
if right.dtype == dtypes.BOOL_DTYPE and left.dtype != dtypes.BOOL_DTYPE:
86-
right_expr = sge.Cast(this=right_expr, to="INT64")
83+
if left.dtype != dtypes.BOOL_DTYPE:
84+
right_expr = right.coerce_bool_to_int()
8785

8886
sentinel = sge.convert("$NULL_SENTINEL$")
8987
left_coalesce = sge.Coalesce(
@@ -97,7 +95,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
9795

9896
@BINARY_OP_REGISTRATION.register(ops.div_op)
9997
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
100-
left_expr, right_expr = _coerce_bools(left, right)
98+
left_expr = left.coerce_bool_to_int()
99+
right_expr = right.coerce_bool_to_int()
101100

102101
result = sge.func("IEEE_DIVIDE", left_expr, right_expr)
103102
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
@@ -108,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
108107

109108
@BINARY_OP_REGISTRATION.register(ops.floordiv_op)
110109
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
111-
left_expr = left.expr
112-
if left.dtype == dtypes.BOOL_DTYPE:
113-
left_expr = sge.Cast(this=left_expr, to="INT64")
114-
right_expr = right.expr
115-
if right.dtype == dtypes.BOOL_DTYPE:
116-
right_expr = sge.Cast(this=right_expr, to="INT64")
110+
left_expr = left.coerce_bool_to_int()
111+
right_expr = right.coerce_bool_to_int()
117112

118113
result: sge.Expression = sge.Cast(
119114
this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64"
@@ -155,7 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
155150

156151
@BINARY_OP_REGISTRATION.register(ops.mul_op)
157152
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
158-
left_expr, right_expr = _coerce_bools(left, right)
153+
left_expr = left.coerce_bool_to_int()
154+
right_expr = right.coerce_bool_to_int()
159155

160156
result = sge.Mul(this=left_expr, expression=right_expr)
161157

@@ -169,35 +165,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
169165

170166
@BINARY_OP_REGISTRATION.register(ops.ne_op)
171167
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
172-
left_expr, right_expr = _coerce_bools(left, right)
168+
left_expr = left.coerce_bool_to_int()
169+
right_expr = right.coerce_bool_to_int()
173170
return sge.NEQ(this=left_expr, expression=right_expr)
174171

175172

176173
@BINARY_OP_REGISTRATION.register(ops.sub_op)
177174
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
178175
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
179-
left_expr, right_expr = _coerce_bools(left, right)
176+
left_expr = left.coerce_bool_to_int()
177+
right_expr = right.coerce_bool_to_int()
180178
return sge.Sub(this=left_expr, expression=right_expr)
181179

182180
if (
183181
dtypes.is_time_or_date_like(left.dtype)
184182
and right.dtype == dtypes.TIMEDELTA_DTYPE
185183
):
186-
left_expr = left.expr
187-
if left.dtype == dtypes.DATE_DTYPE:
188-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
184+
left_expr = left.coerce_date_to_datetime()
189185
return sge.TimestampSub(
190186
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
191187
)
192188
if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like(
193189
right.dtype
194190
):
195-
left_expr = left.expr
196-
if left.dtype == dtypes.DATE_DTYPE:
197-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
198-
right_expr = right.expr
199-
if right.dtype == dtypes.DATE_DTYPE:
200-
right_expr = sge.Cast(this=right_expr, to="DATETIME")
191+
left_expr = left.coerce_date_to_datetime()
192+
right_expr = right.coerce_date_to_datetime()
201193
return sge.TimestampDiff(
202194
this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND")
203195
)
@@ -213,16 +205,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
213205
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
214206
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
215207
return sge.func("OBJ.MAKE_REF", left.expr, right.expr)
216-
217-
218-
def _coerce_bools(
219-
left: TypedExpr, right: TypedExpr
220-
) -> tuple[sge.Expression, sge.Expression]:
221-
"""Coerce boolean expressions to INT64 for binary operations."""
222-
left_expr = left.expr
223-
if left.dtype == dtypes.BOOL_DTYPE:
224-
left_expr = sge.Cast(this=left_expr, to="INT64")
225-
right_expr = right.expr
226-
if right.dtype == dtypes.BOOL_DTYPE:
227-
right_expr = sge.Cast(this=right_expr, to="INT64")
228-
return left_expr, right_expr

bigframes/core/compile/sqlglot/expressions/typed_expr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,15 @@ class TypedExpr:
2525

2626
expr: sge.Expression
2727
dtype: dtypes.ExpressionType
28+
29+
def coerce_bool_to_int(self) -> sge.Expression:
30+
"""Coerce boolean expression to integer."""
31+
if self.dtype == dtypes.BOOL_DTYPE:
32+
return sge.Cast(this=self.expr, to="INT64")
33+
return self.expr
34+
35+
def coerce_date_to_datetime(self) -> sge.Expression:
36+
"""Coerce date expression to datetime."""
37+
if self.dtype == dtypes.DATE_DTYPE:
38+
return sge.Cast(this=self.expr, to="DATETIME")
39+
return self.expr

0 commit comments

Comments
 (0)