@@ -37,26 +37,259 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
37
37
return expr .op .as_expr (larg , rarg )
38
38
39
39
40
class LowerAddRule(op_lowering.OpLoweringRule):
    """Lowering rule applied to ``numeric_ops.AddOp`` expressions.

    Inserts explicit casts around the operands (or substitutes string
    concatenation) and rebuilds the addition from the adjusted operands.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op type handled by this rule."""
        return numeric_ops.AddOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered equivalent of an addition expression.

        Cases, checked in this order:
          * bool + bool: both operands are cast to int, added, and the sum
            is cast back to bool.
          * string-like + string-like: replaced by ``ops.strconcat_op``.
          * a lone bool operand is cast to int.
          * a date operand paired with a timedelta operand (either side)
            is cast to datetime.
        """
        assert isinstance(expr.op, numeric_ops.AddOp)
        left, right = expr.children[0], expr.children[1]

        def cast(arg: expression.Expression, to_type) -> expression.Expression:
            # Wrap `arg` in a fresh AsTypeOp targeting `to_type`.
            return ops.AsTypeOp(to_type=to_type).as_expr(arg)

        left_is_bool = left.output_type == dtypes.BOOL_DTYPE
        right_is_bool = right.output_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            int_sum = expr.op.as_expr(
                cast(left, dtypes.INT_DTYPE),
                cast(right, dtypes.INT_DTYPE),
            )
            return cast(int_sum, dtypes.BOOL_DTYPE)

        both_string_like = dtypes.is_string_like(
            left.output_type
        ) and dtypes.is_string_like(right.output_type)
        if both_string_like:
            return ops.strconcat_op.as_expr(left, right)

        # At most one side is bool here; the both-bool case returned above.
        if left_is_bool:
            left = cast(left, dtypes.INT_DTYPE)
        if right_is_bool:
            right = cast(right, dtypes.INT_DTYPE)

        # Output types are re-read after the bool casts on purpose.
        date_plus_delta = (
            left.output_type == dtypes.DATE_DTYPE
            and right.output_type == dtypes.TIMEDELTA_DTYPE
        )
        if date_plus_delta:
            left = cast(left, dtypes.DATETIME_DTYPE)

        delta_plus_date = (
            left.output_type == dtypes.TIMEDELTA_DTYPE
            and right.output_type == dtypes.DATE_DTYPE
        )
        if delta_plus_date:
            right = cast(right, dtypes.DATETIME_DTYPE)

        return expr.op.as_expr(left, right)
82
+
83
+
84
class LowerSubRule(op_lowering.OpLoweringRule):
    """Lowering rule applied to ``numeric_ops.SubOp`` expressions."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op type handled by this rule."""
        return numeric_ops.SubOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered equivalent of a subtraction expression.

        Cases, checked in this order:
          * bool - bool: both operands are cast to int, subtracted, and the
            difference is cast back to bool.
          * a lone bool operand is cast to int.
          * a date left operand with a timedelta right operand is cast to
            datetime.
        """
        assert isinstance(expr.op, numeric_ops.SubOp)
        left, right = expr.children[0], expr.children[1]

        def cast(arg: expression.Expression, to_type) -> expression.Expression:
            # Wrap `arg` in a fresh AsTypeOp targeting `to_type`.
            return ops.AsTypeOp(to_type=to_type).as_expr(arg)

        left_is_bool = left.output_type == dtypes.BOOL_DTYPE
        right_is_bool = right.output_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            int_difference = expr.op.as_expr(
                cast(left, dtypes.INT_DTYPE),
                cast(right, dtypes.INT_DTYPE),
            )
            return cast(int_difference, dtypes.BOOL_DTYPE)

        # At most one side is bool here; the both-bool case returned above.
        if left_is_bool:
            left = cast(left, dtypes.INT_DTYPE)
        if right_is_bool:
            right = cast(right, dtypes.INT_DTYPE)

        # Output types are re-read after the bool casts on purpose.
        date_minus_delta = (
            left.output_type == dtypes.DATE_DTYPE
            and right.output_type == dtypes.TIMEDELTA_DTYPE
        )
        if date_minus_delta:
            left = cast(left, dtypes.DATETIME_DTYPE)

        return expr.op.as_expr(left, right)
115
+
116
+
117
# NOTE(review): the other arithmetic rules visible in this module are plain
# classes; confirm the dataclass decorator here is intentional.
@dataclasses.dataclass
class LowerMulRule(op_lowering.OpLoweringRule):
    """Lowering rule applied to ``numeric_ops.MulOp`` expressions."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op type handled by this rule."""
        return numeric_ops.MulOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered equivalent of a multiplication expression.

        Cases:
          * bool * bool: both operands are cast to int, multiplied, and the
            product is cast back to bool.
          * exactly one bool operand: that operand is cast to int.
          * otherwise the multiplication is rebuilt from the operands as-is.
        """
        assert isinstance(expr.op, numeric_ops.MulOp)
        left, right = expr.children[0], expr.children[1]

        def cast(arg: expression.Expression, to_type) -> expression.Expression:
            # Wrap `arg` in a fresh AsTypeOp targeting `to_type`.
            return ops.AsTypeOp(to_type=to_type).as_expr(arg)

        left_is_bool = left.output_type == dtypes.BOOL_DTYPE
        right_is_bool = right.output_type == dtypes.BOOL_DTYPE

        if left_is_bool and right_is_bool:
            int_product = expr.op.as_expr(
                cast(left, dtypes.INT_DTYPE),
                cast(right, dtypes.INT_DTYPE),
            )
            return cast(int_product, dtypes.BOOL_DTYPE)

        # Past the early return at most one operand is bool, so the two
        # single-bool cases are mutually exclusive.
        if left_is_bool:
            left = cast(left, dtypes.INT_DTYPE)
        elif right_is_bool:
            right = cast(right, dtypes.INT_DTYPE)

        return expr.op.as_expr(left, right)
149
+
150
+
151
class LowerDivRule(op_lowering.OpLoweringRule):
    """Lowering rule applied to ``numeric_ops.DivOp`` expressions."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op type handled by this rule."""
        return numeric_ops.DivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered equivalent of a division expression.

        Cases, checked in this order:
          * timedelta / numeric: the dividend is cast to int, floor-divided
            by the divisor, and the result is cast to int and then back to
            timedelta.
          * bool / bool: both operands are cast to int, divided, and the
            quotient is cast to bool.
          * otherwise: a bool dividend is cast to int, a NUMERIC or
            BIGNUMERIC dividend is cast to float, a bool divisor is cast to
            int, and the division is rebuilt with ``numeric_ops.div_op``.
        """
        assert isinstance(expr.op, numeric_ops.DivOp)
        numerator, denominator = expr.children[0], expr.children[1]

        def cast(arg: expression.Expression, to_type) -> expression.Expression:
            # Wrap `arg` in a fresh AsTypeOp targeting `to_type`.
            return ops.AsTypeOp(to_type=to_type).as_expr(arg)

        timedelta_by_number = (
            numerator.output_type == dtypes.TIMEDELTA_DTYPE
            and dtypes.is_numeric(denominator.output_type)
        )
        if timedelta_by_number:
            # exact same as floordiv impl for timedelta
            floored = ops.floordiv_op.as_expr(
                cast(numerator, dtypes.INT_DTYPE), denominator
            )
            return cast(cast(floored, dtypes.INT_DTYPE), dtypes.TIMEDELTA_DTYPE)

        both_bool = (
            numerator.output_type == dtypes.BOOL_DTYPE
            and denominator.output_type == dtypes.BOOL_DTYPE
        )
        if both_bool:
            quotient = expr.op.as_expr(
                cast(numerator, dtypes.INT_DTYPE),
                cast(denominator, dtypes.INT_DTYPE),
            )
            return cast(quotient, dtypes.BOOL_DTYPE)

        # polars divide doesn't like bools, convert to int always
        # convert numerics to float always
        if numerator.output_type == dtypes.BOOL_DTYPE:
            numerator = cast(numerator, dtypes.INT_DTYPE)
        elif numerator.output_type in (
            dtypes.BIGNUMERIC_DTYPE,
            dtypes.NUMERIC_DTYPE,
        ):
            numerator = cast(numerator, dtypes.FLOAT_DTYPE)

        if denominator.output_type == dtypes.BOOL_DTYPE:
            denominator = cast(denominator, dtypes.INT_DTYPE)

        return numeric_ops.div_op.as_expr(numerator, denominator)
192
+
193
+
40
194
class LowerFloorDivRule(op_lowering.OpLoweringRule):
    """Lowering rule applied to ``numeric_ops.FloorDivOp`` expressions."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op type handled by this rule."""
        return numeric_ops.FloorDivOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered equivalent of a floor-division expression.

        Cases, checked in this order:
          * timedelta // timedelta: both operands are cast to int and the
            integer floor division is returned as-is.
          * timedelta // numeric: the dividend is cast to int, divided, and
            the result is cast to int and then back to timedelta.
          * otherwise bool operands are cast to int; when the original
            expression's output type is not float, the division is wrapped
            in a zero-divisor guard.
        """
        assert isinstance(expr.op, numeric_ops.FloorDivOp)
        numerator, denominator = expr.children[0], expr.children[1]

        def cast(arg: expression.Expression, to_type) -> expression.Expression:
            # Wrap `arg` in a fresh AsTypeOp targeting `to_type`.
            return ops.AsTypeOp(to_type=to_type).as_expr(arg)

        numerator_is_timedelta = numerator.output_type == dtypes.TIMEDELTA_DTYPE

        if (
            numerator_is_timedelta
            and denominator.output_type == dtypes.TIMEDELTA_DTYPE
        ):
            return expr.op.as_expr(
                cast(numerator, dtypes.INT_DTYPE),
                cast(denominator, dtypes.INT_DTYPE),
            )

        if numerator_is_timedelta and dtypes.is_numeric(denominator.output_type):
            # this is pretty fragile as zero will break it, and must fit back into int
            raw_quotient = expr.op.as_expr(
                cast(numerator, dtypes.INT_DTYPE), denominator
            )
            return cast(
                cast(raw_quotient, dtypes.INT_DTYPE), dtypes.TIMEDELTA_DTYPE
            )

        if numerator.output_type == dtypes.BOOL_DTYPE:
            numerator = cast(numerator, dtypes.INT_DTYPE)
        if denominator.output_type == dtypes.BOOL_DTYPE:
            denominator = cast(denominator, dtypes.INT_DTYPE)

        # Float results need no zero guard; rebuild the op directly.
        if expr.output_type == dtypes.FLOAT_DTYPE:
            return expr.op.as_expr(numerator, denominator)

        # need to guard against zero divisor
        # multiply dividend in this case to propagate nulls
        zero_like_numerator = ops.mul_op.as_expr(numerator, expression.const(0))
        denominator_is_zero = ops.eq_op.as_expr(denominator, expression.const(0))
        guarded_quotient = numeric_ops.floordiv_op.as_expr(numerator, denominator)
        return ops.where_op.as_expr(
            zero_like_numerator,
            denominator_is_zero,
            guarded_quotient,
        )
239
+
240
+
241
class LowerModRule(op_lowering.OpLoweringRule):
    """Lowering rule applied to ``numeric_ops.ModOp`` expressions."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """Op type handled by this rule."""
        return numeric_ops.ModOp

    @staticmethod
    def _guard_zero_divisor(
        mod_result: expression.Expression,
        dividend: expression.Expression,
        divisor: expression.Expression,
    ) -> expression.Expression:
        """Wrap ``mod_result`` so a zero divisor yields ``dividend * 0``.

        Uses the same ``where_op`` argument pattern as the zero guard in
        LowerFloorDivRule: (``dividend * 0``, ``divisor == 0``, result).
        Multiplying the dividend, rather than the divisor, keeps a null
        dividend null when the divisor is zero. A null divisor makes the
        condition null, so the plain modulo result is used instead.
        NOTE(review): assumes ``where_op`` takes the fallback branch on a
        null condition and that modulo by null is null - confirm against
        the backend.
        """
        return ops.where_op.as_expr(
            ops.mul_op.as_expr(dividend, expression.const(0)),
            ops.eq_op.as_expr(divisor, expression.const(0)),
            mod_result,
        )

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered equivalent of a modulo expression.

        Cases, checked in this order:
          * timedelta % timedelta: both operands are cast to int, the
            modulo is taken with a zero-divisor guard, and the result is
            cast back to timedelta.
          * otherwise bool operands are cast to int; when the original
            expression's output type is int, the modulo is wrapped in the
            same zero-divisor guard.
        """
        assert isinstance(expr.op, numeric_ops.ModOp)
        larg, rarg = expr.children[0], expr.children[1]

        if (
            larg.output_type == dtypes.TIMEDELTA_DTYPE
            and rarg.output_type == dtypes.TIMEDELTA_DTYPE
        ):
            larg_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
            rarg_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)
            guarded = self._guard_zero_divisor(
                expr.op.as_expr(larg_int, rarg_int), larg_int, rarg_int
            )
            return ops.AsTypeOp(to_type=dtypes.TIMEDELTA_DTYPE).as_expr(guarded)

        if larg.output_type == dtypes.BOOL_DTYPE:
            larg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(larg)
        if rarg.output_type == dtypes.BOOL_DTYPE:
            rarg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(rarg)

        wo_bools = expr.op.as_expr(larg, rarg)

        # Only integer results get the zero-divisor guard.
        if expr.output_type == dtypes.INT_DTYPE:
            return self._guard_zero_divisor(wo_bools, larg, rarg)
        return wo_bools
57
279
58
280
59
- def _coerce_comparables (expr1 : expression .Expression , expr2 : expression .Expression ):
281
+ def _coerce_comparables (
282
+ expr1 : expression .Expression ,
283
+ expr2 : expression .Expression ,
284
+ * ,
285
+ bools_only : bool = False
286
+ ):
287
+ if bools_only :
288
+ if (
289
+ expr1 .output_type != dtypes .BOOL_DTYPE
290
+ and expr2 .output_type != dtypes .BOOL_DTYPE
291
+ ):
292
+ return expr1 , expr2
60
293
61
294
target_type = dtypes .coerce_to_common (expr1 .output_type , expr2 .output_type )
62
295
if expr1 .output_type != target_type :
@@ -90,7 +323,12 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
90
323
91
324
# All lowering rules, in order: the entries of LOWER_COMPARISONS followed by
# one instance of each arithmetic rule.
POLARS_LOWERING_RULES = (
    *LOWER_COMPARISONS,
    LowerAddRule(),
    LowerSubRule(),
    LowerMulRule(),
    LowerDivRule(),
    LowerFloorDivRule(),
    LowerModRule(),
)
95
333
96
334
0 commit comments