75
75
name = "ceil" , type_signature = op_typing .UNARY_REAL_NUMERIC
76
76
)
77
77
78
- abs_op = base_ops .create_unary_op (name = "abs" , type_signature = op_typing .UNARY_NUMERIC )
78
+ abs_op = base_ops .create_unary_op (
79
+ name = "abs" , type_signature = op_typing .UNARY_NUMERIC_AND_TIMEDELTA
80
+ )
79
81
80
- pos_op = base_ops .create_unary_op (name = "pos" , type_signature = op_typing .UNARY_NUMERIC )
82
+ pos_op = base_ops .create_unary_op (
83
+ name = "pos" , type_signature = op_typing .UNARY_NUMERIC_AND_TIMEDELTA
84
+ )
81
85
82
- neg_op = base_ops .create_unary_op (name = "neg" , type_signature = op_typing .UNARY_NUMERIC )
86
+ neg_op = base_ops .create_unary_op (
87
+ name = "neg" , type_signature = op_typing .UNARY_NUMERIC_AND_TIMEDELTA
88
+ )
83
89
84
90
exp_op = base_ops .create_unary_op (
85
91
name = "exp" , type_signature = op_typing .UNARY_REAL_NUMERIC
@@ -123,6 +129,9 @@ def output_type(self, *input_types):
123
129
if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_datetime_like (right_type ):
124
130
return right_type
125
131
132
+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
133
+ return dtypes .TIMEDELTA_DTYPE
134
+
126
135
if (left_type is None or dtypes .is_numeric (left_type )) and (
127
136
right_type is None or dtypes .is_numeric (right_type )
128
137
):
@@ -142,32 +151,102 @@ class SubOp(base_ops.BinaryOp):
142
151
def output_type (self , * input_types ):
143
152
left_type = input_types [0 ]
144
153
right_type = input_types [1 ]
145
- if (left_type is None or dtypes .is_numeric (left_type )) and (
146
- right_type is None or dtypes .is_numeric (right_type )
147
- ):
148
- # Numeric subtraction
149
- return dtypes .coerce_to_common (left_type , right_type )
150
154
151
155
if dtypes .is_datetime_like (left_type ) and dtypes .is_datetime_like (right_type ):
152
156
return dtypes .TIMEDELTA_DTYPE
153
157
154
158
if dtypes .is_datetime_like (left_type ) and right_type is dtypes .TIMEDELTA_DTYPE :
155
159
return left_type
156
160
161
+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
162
+ return dtypes .TIMEDELTA_DTYPE
163
+
164
+ if (left_type is None or dtypes .is_numeric (left_type )) and (
165
+ right_type is None or dtypes .is_numeric (right_type )
166
+ ):
167
+ # Numeric subtraction
168
+ return dtypes .coerce_to_common (left_type , right_type )
169
+
157
170
raise TypeError (f"Cannot subtract dtypes { left_type } and { right_type } " )
158
171
159
172
160
173
sub_op = SubOp ()
161
174
162
- mul_op = base_ops .create_binary_op (name = "mul" , type_signature = op_typing .BINARY_NUMERIC )
163
175
164
- div_op = base_ops . create_binary_op (
165
- name = "div" , type_signature = op_typing . BINARY_REAL_NUMERIC
166
- )
176
+ @ dataclasses . dataclass ( frozen = True )
177
+ class MulOp ( base_ops . BinaryOp ):
178
+ name : typing . ClassVar [ str ] = "mul"
167
179
168
- floordiv_op = base_ops .create_binary_op (
169
- name = "floordiv" , type_signature = op_typing .BINARY_NUMERIC
170
- )
180
+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
181
+ left_type = input_types [0 ]
182
+ right_type = input_types [1 ]
183
+
184
+ if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right_type ):
185
+ return dtypes .TIMEDELTA_DTYPE
186
+ if dtypes .is_numeric (left_type ) and right_type is dtypes .TIMEDELTA_DTYPE :
187
+ return dtypes .TIMEDELTA_DTYPE
188
+
189
+ if (left_type is None or dtypes .is_numeric (left_type )) and (
190
+ right_type is None or dtypes .is_numeric (right_type )
191
+ ):
192
+ return dtypes .coerce_to_common (left_type , right_type )
193
+
194
+ raise TypeError (f"Cannot multiply dtypes { left_type } and { right_type } " )
195
+
196
+
197
+ mul_op = MulOp ()
198
+
199
+
200
+ @dataclasses .dataclass (frozen = True )
201
+ class DivOp (base_ops .BinaryOp ):
202
+ name : typing .ClassVar [str ] = "div"
203
+
204
+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
205
+ left_type = input_types [0 ]
206
+ right_type = input_types [1 ]
207
+
208
+ if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right_type ):
209
+ return dtypes .TIMEDELTA_DTYPE
210
+
211
+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
212
+ return dtypes .FLOAT_DTYPE
213
+
214
+ if (left_type is None or dtypes .is_numeric (left_type )) and (
215
+ right_type is None or dtypes .is_numeric (right_type )
216
+ ):
217
+ lcd_type = dtypes .coerce_to_common (left_type , right_type )
218
+ # Real numeric ops produce floats on int input
219
+ return dtypes .FLOAT_DTYPE if lcd_type == dtypes .INT_DTYPE else lcd_type
220
+
221
+ raise TypeError (f"Cannot divide dtypes { left_type } and { right_type } " )
222
+
223
+
224
+ div_op = DivOp ()
225
+
226
+
227
+ @dataclasses .dataclass (frozen = True )
228
+ class FloorDivOp (base_ops .BinaryOp ):
229
+ name : typing .ClassVar [str ] = "floordiv"
230
+
231
+ def output_type (self , * input_types : dtypes .ExpressionType ) -> dtypes .ExpressionType :
232
+ left_type = input_types [0 ]
233
+ right_type = input_types [1 ]
234
+
235
+ if left_type is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right_type ):
236
+ return dtypes .TIMEDELTA_DTYPE
237
+
238
+ if left_type is dtypes .TIMEDELTA_DTYPE and right_type is dtypes .TIMEDELTA_DTYPE :
239
+ return dtypes .INT_DTYPE
240
+
241
+ if (left_type is None or dtypes .is_numeric (left_type )) and (
242
+ right_type is None or dtypes .is_numeric (right_type )
243
+ ):
244
+ return dtypes .coerce_to_common (left_type , right_type )
245
+
246
+ raise TypeError (f"Cannot floor divide dtypes { left_type } and { right_type } " )
247
+
248
+
249
+ floordiv_op = FloorDivOp ()
171
250
172
251
pow_op = base_ops .create_binary_op (name = "pow" , type_signature = op_typing .BINARY_NUMERIC )
173
252
0 commit comments