Skip to content

Commit ba1da16

Browse files
authored
chore: simplify dask binary ops (#2992)
* chore: simplify dask * typing
1 parent 81d272c commit ba1da16

File tree

1 file changed

+63
-76
lines changed

1 file changed

+63
-76
lines changed

narwhals/_dask/expr.py

Lines changed: 63 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -186,105 +186,92 @@ def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
186186
scalar_kwargs=self._scalar_kwargs,
187187
)
188188

189-
def __add__(self, other: Any) -> Self:
190-
return self._with_callable(
191-
lambda expr, other: expr.__add__(other), "__add__", other=other
189+
def _with_binary(
190+
self,
191+
call: Callable[[dx.Series, Any], dx.Series],
192+
name: str,
193+
other: Any,
194+
*,
195+
reverse: bool = False,
196+
) -> Self:
197+
result = self._with_callable(
198+
lambda expr, other: call(expr, other), name, other=other
192199
)
200+
if reverse:
201+
result = result.alias("literal")
202+
return result
193203

194-
def __sub__(self, other: Any) -> Self:
195-
return self._with_callable(
196-
lambda expr, other: expr.__sub__(other), "__sub__", other=other
204+
def _binary_op(self, op_name: str, other: Any) -> Self:
205+
return self._with_binary(
206+
lambda expr, other: getattr(expr, op_name)(other), op_name, other
197207
)
198208

199-
def __rsub__(self, other: Any) -> Self:
200-
return self._with_callable(
201-
lambda expr, other: other - expr, "__rsub__", other=other
202-
).alias("literal")
209+
def _reverse_binary_op(
210+
self, op_name: str, operator_func: Callable[..., dx.Series], other: Any
211+
) -> Self:
212+
return self._with_binary(
213+
lambda expr, other: operator_func(other, expr), op_name, other, reverse=True
214+
)
215+
216+
def __add__(self, other: Any) -> Self:
217+
return self._binary_op("__add__", other)
218+
219+
def __sub__(self, other: Any) -> Self:
220+
return self._binary_op("__sub__", other)
203221

204222
def __mul__(self, other: Any) -> Self:
205-
return self._with_callable(
206-
lambda expr, other: expr.__mul__(other), "__mul__", other=other
207-
)
223+
return self._binary_op("__mul__", other)
208224

209225
def __truediv__(self, other: Any) -> Self:
210-
return self._with_callable(
211-
lambda expr, other: expr.__truediv__(other), "__truediv__", other=other
212-
)
213-
214-
def __rtruediv__(self, other: Any) -> Self:
215-
return self._with_callable(
216-
lambda expr, other: other / expr, "__rtruediv__", other=other
217-
).alias("literal")
226+
return self._binary_op("__truediv__", other)
218227

219228
def __floordiv__(self, other: Any) -> Self:
220-
return self._with_callable(
221-
lambda expr, other: expr.__floordiv__(other), "__floordiv__", other=other
222-
)
223-
224-
def __rfloordiv__(self, other: Any) -> Self:
225-
return self._with_callable(
226-
lambda expr, other: other // expr, "__rfloordiv__", other=other
227-
).alias("literal")
229+
return self._binary_op("__floordiv__", other)
228230

229231
def __pow__(self, other: Any) -> Self:
230-
return self._with_callable(
231-
lambda expr, other: expr.__pow__(other), "__pow__", other=other
232-
)
233-
234-
def __rpow__(self, other: Any) -> Self:
235-
return self._with_callable(
236-
lambda expr, other: other**expr, "__rpow__", other=other
237-
).alias("literal")
232+
return self._binary_op("__pow__", other)
238233

239234
def __mod__(self, other: Any) -> Self:
240-
return self._with_callable(
241-
lambda expr, other: expr.__mod__(other), "__mod__", other=other
242-
)
235+
return self._binary_op("__mod__", other)
243236

244-
def __rmod__(self, other: Any) -> Self:
245-
return self._with_callable(
246-
lambda expr, other: other % expr, "__rmod__", other=other
247-
).alias("literal")
237+
def __eq__(self, other: object) -> Self: # type: ignore[override]
238+
return self._binary_op("__eq__", other)
248239

249-
def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override]
250-
return self._with_callable(
251-
lambda expr, other: expr.__eq__(other), "__eq__", other=other
252-
)
240+
def __ne__(self, other: object) -> Self: # type: ignore[override]
241+
return self._binary_op("__ne__", other)
253242

254-
def __ne__(self, other: DaskExpr) -> Self: # type: ignore[override]
255-
return self._with_callable(
256-
lambda expr, other: expr.__ne__(other), "__ne__", other=other
257-
)
243+
def __ge__(self, other: Any) -> Self:
244+
return self._binary_op("__ge__", other)
258245

259-
def __ge__(self, other: DaskExpr | Any) -> Self:
260-
return self._with_callable(
261-
lambda expr, other: expr.__ge__(other), "__ge__", other=other
262-
)
246+
def __gt__(self, other: Any) -> Self:
247+
return self._binary_op("__gt__", other)
263248

264-
def __gt__(self, other: DaskExpr) -> Self:
265-
return self._with_callable(
266-
lambda expr, other: expr.__gt__(other), "__gt__", other=other
267-
)
249+
def __le__(self, other: Any) -> Self:
250+
return self._binary_op("__le__", other)
268251

269-
def __le__(self, other: DaskExpr) -> Self:
270-
return self._with_callable(
271-
lambda expr, other: expr.__le__(other), "__le__", other=other
272-
)
252+
def __lt__(self, other: Any) -> Self:
253+
return self._binary_op("__lt__", other)
273254

274-
def __lt__(self, other: DaskExpr) -> Self:
275-
return self._with_callable(
276-
lambda expr, other: expr.__lt__(other), "__lt__", other=other
277-
)
255+
def __and__(self, other: Any) -> Self:
256+
return self._binary_op("__and__", other)
278257

279-
def __and__(self, other: DaskExpr | Any) -> Self:
280-
return self._with_callable(
281-
lambda expr, other: expr.__and__(other), "__and__", other=other
282-
)
258+
def __or__(self, other: Any) -> Self:
259+
return self._binary_op("__or__", other)
283260

284-
def __or__(self, other: DaskExpr) -> Self:
285-
return self._with_callable(
286-
lambda expr, other: expr.__or__(other), "__or__", other=other
287-
)
261+
def __rsub__(self, other: Any) -> Self:
262+
return self._reverse_binary_op("__rsub__", lambda a, b: a - b, other)
263+
264+
def __rtruediv__(self, other: Any) -> Self:
265+
return self._reverse_binary_op("__rtruediv__", lambda a, b: a / b, other)
266+
267+
def __rfloordiv__(self, other: Any) -> Self:
268+
return self._reverse_binary_op("__rfloordiv__", lambda a, b: a // b, other)
269+
270+
def __rpow__(self, other: Any) -> Self:
271+
return self._reverse_binary_op("__rpow__", lambda a, b: a**b, other)
272+
273+
def __rmod__(self, other: Any) -> Self:
274+
return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other)
288275

289276
def __invert__(self) -> Self:
290277
return self._with_callable(lambda expr: expr.__invert__(), "__invert__")

0 commit comments

Comments
 (0)