@@ -38,26 +38,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
38
38
return sge .Concat (expressions = [left .expr , right .expr ])
39
39
40
40
if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
41
- left_expr , right_expr = _coerce_bools (left , right )
41
+ left_expr = _coerce_bool_to_int (left )
42
+ right_expr = _coerce_bool_to_int (right )
42
43
return sge .Add (this = left_expr , expression = right_expr )
43
44
44
45
if (
45
46
dtypes .is_time_or_date_like (left .dtype )
46
47
and right .dtype == dtypes .TIMEDELTA_DTYPE
47
48
):
48
- left_expr = left .expr
49
- if left .dtype == dtypes .DATE_DTYPE :
50
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
49
+ left_expr = _coerce_date_to_datetime (left )
51
50
return sge .TimestampAdd (
52
51
this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
53
52
)
54
53
if (
55
54
dtypes .is_time_or_date_like (right .dtype )
56
55
and left .dtype == dtypes .TIMEDELTA_DTYPE
57
56
):
58
- right_expr = right .expr
59
- if right .dtype == dtypes .DATE_DTYPE :
60
- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
57
+ right_expr = _coerce_date_to_datetime (right )
61
58
return sge .TimestampAdd (
62
59
this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
63
60
)
@@ -71,19 +68,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
71
68
72
69
@BINARY_OP_REGISTRATION .register (ops .eq_op )
73
70
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
74
- left_expr , right_expr = _coerce_bools (left , right )
71
+ left_expr = _coerce_bool_to_int (left )
72
+ right_expr = _coerce_bool_to_int (right )
75
73
return sge .EQ (this = left_expr , expression = right_expr )
76
74
77
75
78
76
@BINARY_OP_REGISTRATION .register (ops .eq_null_match_op )
79
77
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
80
78
left_expr = left .expr
81
- if left . dtype == dtypes . BOOL_DTYPE and right .dtype != dtypes .BOOL_DTYPE :
82
- left_expr = sge . Cast ( this = left_expr , to = "INT64" )
79
+ if right .dtype != dtypes .BOOL_DTYPE :
80
+ left_expr = _coerce_bool_to_int ( left )
83
81
84
82
right_expr = right .expr
85
- if right . dtype == dtypes . BOOL_DTYPE and left .dtype != dtypes .BOOL_DTYPE :
86
- right_expr = sge . Cast ( this = right_expr , to = "INT64" )
83
+ if left .dtype != dtypes .BOOL_DTYPE :
84
+ right_expr = _coerce_bool_to_int ( right )
87
85
88
86
sentinel = sge .convert ("$NULL_SENTINEL$" )
89
87
left_coalesce = sge .Coalesce (
@@ -97,7 +95,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
97
95
98
96
@BINARY_OP_REGISTRATION .register (ops .div_op )
99
97
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
100
- left_expr , right_expr = _coerce_bools (left , right )
98
+ left_expr = _coerce_bool_to_int (left )
99
+ right_expr = _coerce_bool_to_int (right )
101
100
102
101
result = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )
103
102
if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
@@ -108,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
108
107
109
108
@BINARY_OP_REGISTRATION .register (ops .floordiv_op )
110
109
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
111
- left_expr = left .expr
112
- if left .dtype == dtypes .BOOL_DTYPE :
113
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
114
- right_expr = right .expr
115
- if right .dtype == dtypes .BOOL_DTYPE :
116
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
110
+ left_expr = _coerce_bool_to_int (left )
111
+ right_expr = _coerce_bool_to_int (right )
117
112
118
113
result : sge .Expression = sge .Cast (
119
114
this = sge .Floor (this = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )), to = "INT64"
@@ -155,7 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
155
150
156
151
@BINARY_OP_REGISTRATION .register (ops .mul_op )
157
152
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
158
- left_expr , right_expr = _coerce_bools (left , right )
153
+ left_expr = _coerce_bool_to_int (left )
154
+ right_expr = _coerce_bool_to_int (right )
159
155
160
156
result = sge .Mul (this = left_expr , expression = right_expr )
161
157
@@ -169,35 +165,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
169
165
170
166
@BINARY_OP_REGISTRATION .register (ops .ne_op )
171
167
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
172
- left_expr , right_expr = _coerce_bools (left , right )
168
+ left_expr = _coerce_bool_to_int (left )
169
+ right_expr = _coerce_bool_to_int (right )
173
170
return sge .NEQ (this = left_expr , expression = right_expr )
174
171
175
172
176
173
@BINARY_OP_REGISTRATION .register (ops .sub_op )
177
174
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
178
175
if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
179
- left_expr , right_expr = _coerce_bools (left , right )
176
+ left_expr = _coerce_bool_to_int (left )
177
+ right_expr = _coerce_bool_to_int (right )
180
178
return sge .Sub (this = left_expr , expression = right_expr )
181
179
182
180
if (
183
181
dtypes .is_time_or_date_like (left .dtype )
184
182
and right .dtype == dtypes .TIMEDELTA_DTYPE
185
183
):
186
- left_expr = left .expr
187
- if left .dtype == dtypes .DATE_DTYPE :
188
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
184
+ left_expr = _coerce_date_to_datetime (left )
189
185
return sge .TimestampSub (
190
186
this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
191
187
)
192
188
if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
193
189
right .dtype
194
190
):
195
- left_expr = left .expr
196
- if left .dtype == dtypes .DATE_DTYPE :
197
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
198
- right_expr = right .expr
199
- if right .dtype == dtypes .DATE_DTYPE :
200
- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
191
+ left_expr = _coerce_date_to_datetime (left )
192
+ right_expr = _coerce_date_to_datetime (right )
201
193
return sge .TimestampDiff (
202
194
this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
203
195
)
@@ -215,14 +207,15 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
215
207
return sge .func ("OBJ.MAKE_REF" , left .expr , right .expr )
216
208
217
209
218
- def _coerce_bools (
219
- left : TypedExpr , right : TypedExpr
220
- ) -> tuple [sge .Expression , sge .Expression ]:
221
- """Coerce boolean expressions to INT64 for binary operations."""
222
- left_expr = left .expr
223
- if left .dtype == dtypes .BOOL_DTYPE :
224
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
225
- right_expr = right .expr
226
- if right .dtype == dtypes .BOOL_DTYPE :
227
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
228
- return left_expr , right_expr
210
+ def _coerce_bool_to_int (typed_expr : TypedExpr ) -> sge .Expression :
211
+ """Coerce boolean expression to integer."""
212
+ if typed_expr .dtype == dtypes .BOOL_DTYPE :
213
+ return sge .Cast (this = typed_expr .expr , to = "INT64" )
214
+ return typed_expr .expr
215
+
216
+
217
+ def _coerce_date_to_datetime (typed_expr : TypedExpr ) -> sge .Expression :
218
+ """Coerce date expression to datetime."""
219
+ if typed_expr .dtype == dtypes .DATE_DTYPE :
220
+ return sge .Cast (this = typed_expr .expr , to = "DATETIME" )
221
+ return typed_expr .expr
0 commit comments