@@ -125,6 +125,9 @@ def _rewrite_op_expr(
125125 # but for timedeltas: int(timedelta) // float => int(timedelta)
126126 return _rewrite_floordiv_op (inputs [0 ], inputs [1 ])
127127
128+ if isinstance (expr .op , ops .ToTimedeltaOp ):
129+ return _rewrite_to_timedelta_op (expr .op , inputs [0 ])
130+
128131 return _TypedExpr .create_op_expr (expr .op , * inputs )
129132
130133
@@ -154,9 +157,9 @@ def _rewrite_mul_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
154157 result = _TypedExpr .create_op_expr (ops .mul_op , left , right )
155158
156159 if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
157- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
160+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
158161 if dtypes .is_numeric (left .dtype ) and right .dtype is dtypes .TIMEDELTA_DTYPE :
159- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
162+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
160163
161164 return result
162165
@@ -165,7 +168,7 @@ def _rewrite_div_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
165168 result = _TypedExpr .create_op_expr (ops .div_op , left , right )
166169
167170 if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
168- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
171+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
169172
170173 return result
171174
@@ -174,11 +177,19 @@ def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
174177 result = _TypedExpr .create_op_expr (ops .floordiv_op , left , right )
175178
176179 if left .dtype is dtypes .TIMEDELTA_DTYPE and dtypes .is_numeric (right .dtype ):
177- return _TypedExpr .create_op_expr (ops .ToTimedeltaOp ( "us" ) , result )
180+ return _TypedExpr .create_op_expr (ops .timedelta_floor_op , result )
178181
179182 return result
180183
181184
185+ def _rewrite_to_timedelta_op (op : ops .ToTimedeltaOp , arg : _TypedExpr ):
186+ if arg .dtype is dtypes .TIMEDELTA_DTYPE :
187+ # Do nothing for values that are already timedeltas
188+ return arg
189+
190+ return _TypedExpr .create_op_expr (op , arg )
191+
192+
182193@functools .cache
183194def _rewrite_aggregation (
184195 aggregation : ex .Aggregation , schema : schema .ArraySchema
0 commit comments