@@ -37,26 +37,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
37
37
return sge .Concat (expressions = [left .expr , right .expr ])
38
38
39
39
if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
40
- left_expr , right_expr = _coerce_bools (left , right )
40
+ left_expr = left .coerce_bool_to_int ()
41
+ right_expr = right .coerce_bool_to_int ()
41
42
return sge .Add (this = left_expr , expression = right_expr )
42
43
43
44
if (
44
45
dtypes .is_time_or_date_like (left .dtype )
45
46
and right .dtype == dtypes .TIMEDELTA_DTYPE
46
47
):
47
- left_expr = left .expr
48
- if left .dtype == dtypes .DATE_DTYPE :
49
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
48
+ left_expr = left .coerce_date_to_datetime ()
50
49
return sge .TimestampAdd (
51
50
this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
52
51
)
53
52
if (
54
53
dtypes .is_time_or_date_like (right .dtype )
55
54
and left .dtype == dtypes .TIMEDELTA_DTYPE
56
55
):
57
- right_expr = right .expr
58
- if right .dtype == dtypes .DATE_DTYPE :
59
- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
56
+ right_expr = right .coerce_date_to_datetime ()
60
57
return sge .TimestampAdd (
61
58
this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
62
59
)
@@ -70,19 +67,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
70
67
71
68
@BINARY_OP_REGISTRATION .register (ops .eq_op )
72
69
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
73
- left_expr , right_expr = _coerce_bools (left , right )
70
+ left_expr = left .coerce_bool_to_int ()
71
+ right_expr = right .coerce_bool_to_int ()
74
72
return sge .EQ (this = left_expr , expression = right_expr )
75
73
76
74
77
75
@BINARY_OP_REGISTRATION .register (ops .eq_null_match_op )
78
76
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
79
77
left_expr = left .expr
80
- if left . dtype == dtypes . BOOL_DTYPE and right .dtype != dtypes .BOOL_DTYPE :
81
- left_expr = sge . Cast ( this = left_expr , to = "INT64" )
78
+ if right .dtype != dtypes .BOOL_DTYPE :
79
+ left_expr = left . coerce_bool_to_int ( )
82
80
83
81
right_expr = right .expr
84
- if right . dtype == dtypes . BOOL_DTYPE and left .dtype != dtypes .BOOL_DTYPE :
85
- right_expr = sge . Cast ( this = right_expr , to = "INT64" )
82
+ if left .dtype != dtypes .BOOL_DTYPE :
83
+ right_expr = right . coerce_bool_to_int ( )
86
84
87
85
sentinel = sge .convert ("$NULL_SENTINEL$" )
88
86
left_coalesce = sge .Coalesce (
@@ -96,7 +94,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
96
94
97
95
@BINARY_OP_REGISTRATION .register (ops .div_op )
98
96
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
99
- left_expr , right_expr = _coerce_bools (left , right )
97
+ left_expr = left .coerce_bool_to_int ()
98
+ right_expr = right .coerce_bool_to_int ()
100
99
101
100
result = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )
102
101
if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
@@ -117,7 +116,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
117
116
118
117
@BINARY_OP_REGISTRATION .register (ops .mul_op )
119
118
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
120
- left_expr , right_expr = _coerce_bools (left , right )
119
+ left_expr = left .coerce_bool_to_int ()
120
+ right_expr = right .coerce_bool_to_int ()
121
121
122
122
result = sge .Mul (this = left_expr , expression = right_expr )
123
123
@@ -131,35 +131,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
131
131
132
132
@BINARY_OP_REGISTRATION .register (ops .ne_op )
133
133
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
134
- left_expr , right_expr = _coerce_bools (left , right )
134
+ left_expr = left .coerce_bool_to_int ()
135
+ right_expr = right .coerce_bool_to_int ()
135
136
return sge .NEQ (this = left_expr , expression = right_expr )
136
137
137
138
138
139
@BINARY_OP_REGISTRATION .register (ops .sub_op )
139
140
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
140
141
if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
141
- left_expr , right_expr = _coerce_bools (left , right )
142
+ left_expr = left .coerce_bool_to_int ()
143
+ right_expr = right .coerce_bool_to_int ()
142
144
return sge .Sub (this = left_expr , expression = right_expr )
143
145
144
146
if (
145
147
dtypes .is_time_or_date_like (left .dtype )
146
148
and right .dtype == dtypes .TIMEDELTA_DTYPE
147
149
):
148
- left_expr = left .expr
149
- if left .dtype == dtypes .DATE_DTYPE :
150
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
150
+ left_expr = left .coerce_date_to_datetime ()
151
151
return sge .TimestampSub (
152
152
this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
153
153
)
154
154
if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
155
155
right .dtype
156
156
):
157
- left_expr = left .expr
158
- if left .dtype == dtypes .DATE_DTYPE :
159
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
160
- right_expr = right .expr
161
- if right .dtype == dtypes .DATE_DTYPE :
162
- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
157
+ left_expr = left .coerce_date_to_datetime ()
158
+ right_expr = right .coerce_date_to_datetime ()
163
159
return sge .TimestampDiff (
164
160
this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
165
161
)
@@ -175,16 +171,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
175
171
@BINARY_OP_REGISTRATION .register (ops .obj_make_ref_op )
176
172
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
177
173
return sge .func ("OBJ.MAKE_REF" , left .expr , right .expr )
178
-
179
-
180
- def _coerce_bools (
181
- left : TypedExpr , right : TypedExpr
182
- ) -> tuple [sge .Expression , sge .Expression ]:
183
- """Coerce boolean expressions to INT64 for binary operations."""
184
- left_expr = left .expr
185
- if left .dtype == dtypes .BOOL_DTYPE :
186
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
187
- right_expr = right .expr
188
- if right .dtype == dtypes .BOOL_DTYPE :
189
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
190
- return left_expr , right_expr
0 commit comments