Skip to content

Commit ffa63d4

Browse files
authored
chore: support operators between timedeltas (#1396)
* support operators within timedeltas * fix mypy * fix tests * fix format * add even more operators to stay away from floats * Make code slimmer.
1 parent 3cee24b commit ffa63d4

File tree

7 files changed

+251
-19
lines changed

7 files changed

+251
-19
lines changed

bigframes/core/rewrite/timedeltas.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ def _rewrite_op_expr(
103103
if isinstance(expr.op, ops.AddOp):
104104
return _rewrite_add_op(inputs[0], inputs[1])
105105

106+
if isinstance(expr.op, ops.MulOp):
107+
return _rewrite_mul_op(inputs[0], inputs[1])
108+
109+
if isinstance(expr.op, ops.DivOp):
110+
return _rewrite_div_op(inputs[0], inputs[1])
111+
112+
if isinstance(expr.op, ops.FloorDivOp):
113+
# We need to re-write floor div because for numerics: int // float => float
114+
# but for timedeltas: int(timedelta) // float => int(timedelta)
115+
return _rewrite_floordiv_op(inputs[0], inputs[1])
116+
106117
return _TypedExpr.create_op_expr(expr.op, *inputs)
107118

108119

@@ -126,3 +137,32 @@ def _rewrite_add_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
126137
return _TypedExpr.create_op_expr(ops.timestamp_add_op, right, left)
127138

128139
return _TypedExpr.create_op_expr(ops.add_op, left, right)
140+
141+
142+
def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
143+
result = _TypedExpr.create_op_expr(ops.mul_op, left, right)
144+
145+
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
146+
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
147+
if dtypes.is_numeric(left.dtype) and right.dtype is dtypes.TIMEDELTA_DTYPE:
148+
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
149+
150+
return result
151+
152+
153+
def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
154+
result = _TypedExpr.create_op_expr(ops.div_op, left, right)
155+
156+
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
157+
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
158+
159+
return result
160+
161+
162+
def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
163+
result = _TypedExpr.create_op_expr(ops.floordiv_op, left, right)
164+
165+
if left.dtype is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
166+
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
167+
168+
return result

bigframes/operations/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,18 @@
115115
cos_op,
116116
cosh_op,
117117
div_op,
118+
DivOp,
118119
exp_op,
119120
expm1_op,
120121
floor_op,
121122
floordiv_op,
123+
FloorDivOp,
122124
ln_op,
123125
log1p_op,
124126
log10_op,
125127
mod_op,
126128
mul_op,
129+
MulOp,
127130
neg_op,
128131
pos_op,
129132
pow_op,
@@ -282,15 +285,18 @@
282285
"cos_op",
283286
"cosh_op",
284287
"div_op",
288+
"DivOp",
285289
"exp_op",
286290
"expm1_op",
287291
"floor_op",
288292
"floordiv_op",
293+
"FloorDivOp",
289294
"ln_op",
290295
"log1p_op",
291296
"log10_op",
292297
"mod_op",
293298
"mul_op",
299+
"MulOp",
294300
"neg_op",
295301
"pos_op",
296302
"pow_op",

bigframes/operations/numeric_ops.py

Lines changed: 94 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,17 @@
7575
name="ceil", type_signature=op_typing.UNARY_REAL_NUMERIC
7676
)
7777

78-
abs_op = base_ops.create_unary_op(name="abs", type_signature=op_typing.UNARY_NUMERIC)
78+
abs_op = base_ops.create_unary_op(
79+
name="abs", type_signature=op_typing.UNARY_NUMERIC_AND_TIMEDELTA
80+
)
7981

80-
pos_op = base_ops.create_unary_op(name="pos", type_signature=op_typing.UNARY_NUMERIC)
82+
pos_op = base_ops.create_unary_op(
83+
name="pos", type_signature=op_typing.UNARY_NUMERIC_AND_TIMEDELTA
84+
)
8185

82-
neg_op = base_ops.create_unary_op(name="neg", type_signature=op_typing.UNARY_NUMERIC)
86+
neg_op = base_ops.create_unary_op(
87+
name="neg", type_signature=op_typing.UNARY_NUMERIC_AND_TIMEDELTA
88+
)
8389

8490
exp_op = base_ops.create_unary_op(
8591
name="exp", type_signature=op_typing.UNARY_REAL_NUMERIC
@@ -123,6 +129,9 @@ def output_type(self, *input_types):
123129
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_datetime_like(right_type):
124130
return right_type
125131

132+
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
133+
return dtypes.TIMEDELTA_DTYPE
134+
126135
if (left_type is None or dtypes.is_numeric(left_type)) and (
127136
right_type is None or dtypes.is_numeric(right_type)
128137
):
@@ -142,32 +151,102 @@ class SubOp(base_ops.BinaryOp):
142151
def output_type(self, *input_types):
143152
left_type = input_types[0]
144153
right_type = input_types[1]
145-
if (left_type is None or dtypes.is_numeric(left_type)) and (
146-
right_type is None or dtypes.is_numeric(right_type)
147-
):
148-
# Numeric subtraction
149-
return dtypes.coerce_to_common(left_type, right_type)
150154

151155
if dtypes.is_datetime_like(left_type) and dtypes.is_datetime_like(right_type):
152156
return dtypes.TIMEDELTA_DTYPE
153157

154158
if dtypes.is_datetime_like(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
155159
return left_type
156160

161+
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
162+
return dtypes.TIMEDELTA_DTYPE
163+
164+
if (left_type is None or dtypes.is_numeric(left_type)) and (
165+
right_type is None or dtypes.is_numeric(right_type)
166+
):
167+
# Numeric subtraction
168+
return dtypes.coerce_to_common(left_type, right_type)
169+
157170
raise TypeError(f"Cannot subtract dtypes {left_type} and {right_type}")
158171

159172

160173
sub_op = SubOp()
161174

162-
mul_op = base_ops.create_binary_op(name="mul", type_signature=op_typing.BINARY_NUMERIC)
163175

164-
div_op = base_ops.create_binary_op(
165-
name="div", type_signature=op_typing.BINARY_REAL_NUMERIC
166-
)
176+
@dataclasses.dataclass(frozen=True)
177+
class MulOp(base_ops.BinaryOp):
178+
name: typing.ClassVar[str] = "mul"
167179

168-
floordiv_op = base_ops.create_binary_op(
169-
name="floordiv", type_signature=op_typing.BINARY_NUMERIC
170-
)
180+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
181+
left_type = input_types[0]
182+
right_type = input_types[1]
183+
184+
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
185+
return dtypes.TIMEDELTA_DTYPE
186+
if dtypes.is_numeric(left_type) and right_type is dtypes.TIMEDELTA_DTYPE:
187+
return dtypes.TIMEDELTA_DTYPE
188+
189+
if (left_type is None or dtypes.is_numeric(left_type)) and (
190+
right_type is None or dtypes.is_numeric(right_type)
191+
):
192+
return dtypes.coerce_to_common(left_type, right_type)
193+
194+
raise TypeError(f"Cannot multiply dtypes {left_type} and {right_type}")
195+
196+
197+
mul_op = MulOp()
198+
199+
200+
@dataclasses.dataclass(frozen=True)
201+
class DivOp(base_ops.BinaryOp):
202+
name: typing.ClassVar[str] = "div"
203+
204+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
205+
left_type = input_types[0]
206+
right_type = input_types[1]
207+
208+
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
209+
return dtypes.TIMEDELTA_DTYPE
210+
211+
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
212+
return dtypes.FLOAT_DTYPE
213+
214+
if (left_type is None or dtypes.is_numeric(left_type)) and (
215+
right_type is None or dtypes.is_numeric(right_type)
216+
):
217+
lcd_type = dtypes.coerce_to_common(left_type, right_type)
218+
# Real numeric ops produce floats on int input
219+
return dtypes.FLOAT_DTYPE if lcd_type == dtypes.INT_DTYPE else lcd_type
220+
221+
raise TypeError(f"Cannot divide dtypes {left_type} and {right_type}")
222+
223+
224+
div_op = DivOp()
225+
226+
227+
@dataclasses.dataclass(frozen=True)
228+
class FloorDivOp(base_ops.BinaryOp):
229+
name: typing.ClassVar[str] = "floordiv"
230+
231+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
232+
left_type = input_types[0]
233+
right_type = input_types[1]
234+
235+
if left_type is dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right_type):
236+
return dtypes.TIMEDELTA_DTYPE
237+
238+
if left_type is dtypes.TIMEDELTA_DTYPE and right_type is dtypes.TIMEDELTA_DTYPE:
239+
return dtypes.INT_DTYPE
240+
241+
if (left_type is None or dtypes.is_numeric(left_type)) and (
242+
right_type is None or dtypes.is_numeric(right_type)
243+
):
244+
return dtypes.coerce_to_common(left_type, right_type)
245+
246+
raise TypeError(f"Cannot floor divide dtypes {left_type} and {right_type}")
247+
248+
249+
floordiv_op = FloorDivOp()
171250

172251
pow_op = base_ops.create_binary_op(name="pow", type_signature=op_typing.BINARY_NUMERIC)
173252

bigframes/operations/timedelta_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ class ToTimedeltaOp(base_ops.UnaryOp):
2626
unit: typing.Literal["us", "ms", "s", "m", "h", "d", "W"]
2727

2828
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
29-
if input_types[0] in (dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE):
29+
if input_types[0] in (
30+
dtypes.INT_DTYPE,
31+
dtypes.FLOAT_DTYPE,
32+
dtypes.TIMEDELTA_DTYPE,
33+
):
3034
return dtypes.TIMEDELTA_DTYPE
3135
raise TypeError("expected integer or float input")
3236

@@ -56,7 +60,6 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
5660
timestamp_add_op = TimestampAdd()
5761

5862

59-
@dataclasses.dataclass(frozen=True)
6063
class TimestampSub(base_ops.BinaryOp):
6164
name: typing.ClassVar[str] = "timestamp_sub"
6265

bigframes/operations/type.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ def output_type(
224224

225225
# Common type signatures
226226
UNARY_NUMERIC = TypePreserving(bigframes.dtypes.is_numeric, description="numeric")
227+
UNARY_NUMERIC_AND_TIMEDELTA = TypePreserving(
228+
lambda x: bigframes.dtypes.is_numeric(x) or x is bigframes.dtypes.TIMEDELTA_DTYPE,
229+
description="numeric_and_timedelta",
230+
)
227231
UNARY_REAL_NUMERIC = UnaryRealNumeric()
228232
BINARY_NUMERIC = BinaryNumeric()
229233
BINARY_REAL_NUMERIC = BinaryRealNumeric()

bigframes/series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,9 @@ def update(self, other: Union[Series, Sequence, Mapping]) -> None:
964964
)
965965
self._set_block(result._get_block())
966966

967+
def __abs__(self) -> Series:
968+
return self.abs()
969+
967970
def abs(self) -> Series:
968971
return self._apply_unary_op(ops.abs_op)
969972

0 commit comments

Comments
 (0)