@@ -38,31 +38,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
38
38
return sge .Concat (expressions = [left .expr , right .expr ])
39
39
40
40
if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
41
- left_expr = left .expr
42
- if left .dtype == dtypes .BOOL_DTYPE :
43
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
44
- right_expr = right .expr
45
- if right .dtype == dtypes .BOOL_DTYPE :
46
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
41
+ left_expr = _coerce_bool_to_int (left )
42
+ right_expr = _coerce_bool_to_int (right )
47
43
return sge .Add (this = left_expr , expression = right_expr )
48
44
49
45
if (
50
46
dtypes .is_time_or_date_like (left .dtype )
51
47
and right .dtype == dtypes .TIMEDELTA_DTYPE
52
48
):
53
- left_expr = left .expr
54
- if left .dtype == dtypes .DATE_DTYPE :
55
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
49
+ left_expr = _coerce_date_to_datetime (left )
56
50
return sge .TimestampAdd (
57
51
this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
58
52
)
59
53
if (
60
54
dtypes .is_time_or_date_like (right .dtype )
61
55
and left .dtype == dtypes .TIMEDELTA_DTYPE
62
56
):
63
- right_expr = right .expr
64
- if right .dtype == dtypes .DATE_DTYPE :
65
- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
57
+ right_expr = _coerce_date_to_datetime (right )
66
58
return sge .TimestampAdd (
67
59
this = right_expr , expression = left .expr , unit = sge .Var (this = "MICROSECOND" )
68
60
)
@@ -74,14 +66,37 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
74
66
)
75
67
76
68
77
- @BINARY_OP_REGISTRATION .register (ops .div_op )
69
@BINARY_OP_REGISTRATION.register(ops.eq_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``eq_op`` into a SQL equality comparison.

    Each operand is passed through ``_coerce_bool_to_int`` before the
    comparison is built, so a boolean operand is compared as INT64 while
    any other operand is used unchanged.

    Args:
        op: The operator being compiled (not inspected here).
        left: Left-hand operand with its dtype.
        right: Right-hand operand with its dtype.

    Returns:
        An ``sge.EQ`` node comparing the two (possibly cast) operands.
    """
    return sge.EQ(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
74
+
75
+
76
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``eq_null_match_op``: equality in which NULL matches NULL.

    Both operands are cast to STRING and wrapped in ``COALESCE`` with a fixed
    sentinel literal, so two NULL operands compare equal instead of yielding
    NULL.

    Args:
        op: The operator being compiled (not inspected here).
        left: Left-hand operand with its dtype.
        right: Right-hand operand with its dtype.

    Returns:
        An ``sge.EQ`` node over the two coalesced, string-cast operands.
    """
    # An operand goes through the bool -> INT64 coercion only when the *other*
    # operand is not boolean; when both sides are boolean they are left as-is.
    lhs = (
        _coerce_bool_to_int(left) if right.dtype != dtypes.BOOL_DTYPE else left.expr
    )
    rhs = (
        _coerce_bool_to_int(right) if left.dtype != dtypes.BOOL_DTYPE else right.expr
    )

    # NOTE: a non-NULL value whose string form equals the sentinel text would
    # compare equal to NULL under this encoding.
    null_marker = sge.convert("$NULL_SENTINEL$")
    lhs_filled, rhs_filled = (
        sge.Coalesce(
            this=sge.Cast(this=operand, to="STRING"), expressions=[null_marker]
        )
        for operand in (lhs, rhs)
    )
    return sge.EQ(this=lhs_filled, expression=rhs_filled)
94
+
95
+
96
+ @BINARY_OP_REGISTRATION .register (ops .div_op )
97
+ def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
98
+ left_expr = _coerce_bool_to_int (left )
99
+ right_expr = _coerce_bool_to_int (right )
85
100
86
101
result = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )
87
102
if left .dtype == dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
@@ -92,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
92
107
93
108
@BINARY_OP_REGISTRATION .register (ops .floordiv_op )
94
109
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
95
- left_expr = left .expr
96
- if left .dtype == dtypes .BOOL_DTYPE :
97
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
98
- right_expr = right .expr
99
- if right .dtype == dtypes .BOOL_DTYPE :
100
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
110
+ left_expr = _coerce_bool_to_int (left )
111
+ right_expr = _coerce_bool_to_int (right )
101
112
102
113
result : sge .Expression = sge .Cast (
103
114
this = sge .Floor (this = sge .func ("IEEE_DIVIDE" , left_expr , right_expr )), to = "INT64"
@@ -139,12 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
139
150
140
151
@BINARY_OP_REGISTRATION .register (ops .mul_op )
141
152
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
142
- left_expr = left .expr
143
- if left .dtype == dtypes .BOOL_DTYPE :
144
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
145
- right_expr = right .expr
146
- if right .dtype == dtypes .BOOL_DTYPE :
147
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
153
+ left_expr = _coerce_bool_to_int (left )
154
+ right_expr = _coerce_bool_to_int (right )
148
155
149
156
result = sge .Mul (this = left_expr , expression = right_expr )
150
157
@@ -156,36 +163,33 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
156
163
return result
157
164
158
165
166
@BINARY_OP_REGISTRATION.register(ops.ne_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ne_op`` into a SQL inequality comparison.

    Each operand is passed through ``_coerce_bool_to_int`` before the
    comparison is built, so a boolean operand is compared as INT64 while
    any other operand is used unchanged.

    Args:
        op: The operator being compiled (not inspected here).
        left: Left-hand operand with its dtype.
        right: Right-hand operand with its dtype.

    Returns:
        An ``sge.NEQ`` node comparing the two (possibly cast) operands.
    """
    return sge.NEQ(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
171
+
172
+
159
173
@BINARY_OP_REGISTRATION .register (ops .sub_op )
160
174
def _ (op , left : TypedExpr , right : TypedExpr ) -> sge .Expression :
161
175
if dtypes .is_numeric (left .dtype ) and dtypes .is_numeric (right .dtype ):
162
- left_expr = left .expr
163
- if left .dtype == dtypes .BOOL_DTYPE :
164
- left_expr = sge .Cast (this = left_expr , to = "INT64" )
165
- right_expr = right .expr
166
- if right .dtype == dtypes .BOOL_DTYPE :
167
- right_expr = sge .Cast (this = right_expr , to = "INT64" )
176
+ left_expr = _coerce_bool_to_int (left )
177
+ right_expr = _coerce_bool_to_int (right )
168
178
return sge .Sub (this = left_expr , expression = right_expr )
169
179
170
180
if (
171
181
dtypes .is_time_or_date_like (left .dtype )
172
182
and right .dtype == dtypes .TIMEDELTA_DTYPE
173
183
):
174
- left_expr = left .expr
175
- if left .dtype == dtypes .DATE_DTYPE :
176
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
184
+ left_expr = _coerce_date_to_datetime (left )
177
185
return sge .TimestampSub (
178
186
this = left_expr , expression = right .expr , unit = sge .Var (this = "MICROSECOND" )
179
187
)
180
188
if dtypes .is_time_or_date_like (left .dtype ) and dtypes .is_time_or_date_like (
181
189
right .dtype
182
190
):
183
- left_expr = left .expr
184
- if left .dtype == dtypes .DATE_DTYPE :
185
- left_expr = sge .Cast (this = left_expr , to = "DATETIME" )
186
- right_expr = right .expr
187
- if right .dtype == dtypes .DATE_DTYPE :
188
- right_expr = sge .Cast (this = right_expr , to = "DATETIME" )
191
+ left_expr = _coerce_date_to_datetime (left )
192
+ right_expr = _coerce_date_to_datetime (right )
189
193
return sge .TimestampDiff (
190
194
this = left_expr , expression = right_expr , unit = sge .Var (this = "MICROSECOND" )
191
195
)
@@ -201,3 +205,17 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
201
205
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``obj_make_ref_op`` into a call to the ``OBJ.MAKE_REF`` function.

    The left and right operand expressions are forwarded, in that order, as
    the two arguments of the function call.

    Args:
        op: The operator being compiled (not inspected here).
        left: Operand supplying the first argument.
        right: Operand supplying the second argument.

    Returns:
        The function-call expression built by ``sge.func``.
    """
    make_ref_args = (left.expr, right.expr)
    return sge.func("OBJ.MAKE_REF", *make_ref_args)
208
+
209
+
210
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to INT64 when its dtype is boolean.

    Args:
        typed_expr: Operand carrying an expression and its dtype.

    Returns:
        ``CAST(<expr> AS INT64)`` if the dtype equals ``dtypes.BOOL_DTYPE``;
        otherwise the operand's expression unchanged.
    """
    needs_cast = typed_expr.dtype == dtypes.BOOL_DTYPE
    return (
        sge.Cast(this=typed_expr.expr, to="INT64") if needs_cast else typed_expr.expr
    )
215
+
216
+
217
def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to DATETIME when its dtype is a date.

    Args:
        typed_expr: Operand carrying an expression and its dtype.

    Returns:
        ``CAST(<expr> AS DATETIME)`` if the dtype equals ``dtypes.DATE_DTYPE``;
        otherwise the operand's expression unchanged.
    """
    needs_cast = typed_expr.dtype == dtypes.DATE_DTYPE
    return (
        sge.Cast(this=typed_expr.expr, to="DATETIME")
        if needs_cast
        else typed_expr.expr
    )
0 commit comments