|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import TYPE_CHECKING, Any, Literal, cast, overload |
| 3 | +from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload |
4 | 4 |
|
5 | 5 | import pyarrow as pa |
6 | 6 | import pyarrow.compute as pc |
@@ -148,6 +148,16 @@ def _with_native( |
148 | 148 | result._broadcast = self._broadcast |
149 | 149 | return result |
150 | 150 |
|
| 151 | + def _with_binary(self, op: Callable[..., ArrayOrScalar], other: Any) -> Self: |
| 152 | + ser, other_native = extract_native(self, other) |
| 153 | + preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True) |
| 154 | + return self._with_native( |
| 155 | + op(ser, other_native), preserve_broadcast=preserve_broadcast |
| 156 | + ).alias(self.name) |
| 157 | + |
| 158 | + def _with_binary_right(self, op: Callable[..., ArrayOrScalar], other: Any) -> Self: |
| 159 | + return self._with_binary(lambda x, y: op(y, x), other).alias(self.name) |
| 160 | + |
151 | 161 | @classmethod |
152 | 162 | def from_iterable( |
153 | 163 | cls, |
@@ -214,106 +224,89 @@ def __narwhals_namespace__(self) -> ArrowNamespace: |
214 | 224 | return ArrowNamespace(version=self._version) |
215 | 225 |
|
216 | 226 | def __eq__(self, other: object) -> Self: # type: ignore[override] |
217 | | - other = cast("PythonLiteral | ArrowSeries | None", other) |
218 | | - ser, rhs = extract_native(self, other) |
219 | | - return self._with_native(pc.equal(ser, rhs)) |
| 227 | + return self._with_binary(pc.equal, other) |
220 | 228 |
|
221 | 229 | def __ne__(self, other: object) -> Self: # type: ignore[override] |
222 | | - other = cast("PythonLiteral | ArrowSeries | None", other) |
223 | | - ser, rhs = extract_native(self, other) |
224 | | - return self._with_native(pc.not_equal(ser, rhs)) |
| 230 | + return self._with_binary(pc.not_equal, other) |
225 | 231 |
|
226 | 232 | def __ge__(self, other: Any) -> Self: |
227 | | - ser, other = extract_native(self, other) |
228 | | - return self._with_native(pc.greater_equal(ser, other)) |
| 233 | + return self._with_binary(pc.greater_equal, other) |
229 | 234 |
|
230 | 235 | def __gt__(self, other: Any) -> Self: |
231 | | - ser, other = extract_native(self, other) |
232 | | - return self._with_native(pc.greater(ser, other)) |
| 236 | + return self._with_binary(pc.greater, other) |
233 | 237 |
|
234 | 238 | def __le__(self, other: Any) -> Self: |
235 | | - ser, other = extract_native(self, other) |
236 | | - return self._with_native(pc.less_equal(ser, other)) |
| 239 | + return self._with_binary(pc.less_equal, other) |
237 | 240 |
|
238 | 241 | def __lt__(self, other: Any) -> Self: |
239 | | - ser, other = extract_native(self, other) |
240 | | - return self._with_native(pc.less(ser, other)) |
| 242 | + return self._with_binary(pc.less, other) |
241 | 243 |
|
242 | 244 | def __and__(self, other: Any) -> Self: |
243 | | - ser, other = extract_native(self, other) |
244 | | - return self._with_native(pc.and_kleene(ser, other)) # type: ignore[arg-type] |
| 245 | + return self._with_binary(pc.and_kleene, other) |
245 | 246 |
|
246 | 247 | def __rand__(self, other: Any) -> Self: |
247 | | - ser, other = extract_native(self, other) |
248 | | - return self._with_native(pc.and_kleene(other, ser)) # type: ignore[arg-type] |
| 248 | + return self._with_binary_right(pc.and_kleene, other) |
249 | 249 |
|
250 | 250 | def __or__(self, other: Any) -> Self: |
251 | | - ser, other = extract_native(self, other) |
252 | | - return self._with_native(pc.or_kleene(ser, other)) # type: ignore[arg-type] |
| 251 | + return self._with_binary_right(pc.or_kleene, other) |
253 | 252 |
|
254 | 253 | def __ror__(self, other: Any) -> Self: |
255 | | - ser, other = extract_native(self, other) |
256 | | - return self._with_native(pc.or_kleene(other, ser)) # type: ignore[arg-type] |
| 254 | + return self._with_binary_right(pc.or_kleene, other) |
257 | 255 |
|
258 | 256 | def __add__(self, other: Any) -> Self: |
259 | | - ser, other = extract_native(self, other) |
260 | | - return self._with_native(pc.add(ser, other)) |
| 257 | + return self._with_binary(pc.add, other) |
261 | 258 |
|
262 | 259 | def __radd__(self, other: Any) -> Self: |
263 | | - return self + other |
| 260 | + return self._with_binary_right(pc.add, other) |
264 | 261 |
|
265 | 262 | def __sub__(self, other: Any) -> Self: |
266 | | - ser, other = extract_native(self, other) |
267 | | - return self._with_native(pc.subtract(ser, other)) |
| 263 | + return self._with_binary(pc.subtract, other) |
268 | 264 |
|
269 | 265 | def __rsub__(self, other: Any) -> Self: |
270 | | - return (self - other) * (-1) |
| 266 | + return self._with_binary_right(pc.subtract, other) |
271 | 267 |
|
272 | 268 | def __mul__(self, other: Any) -> Self: |
273 | | - ser, other = extract_native(self, other) |
274 | | - return self._with_native(pc.multiply(ser, other)) |
| 269 | + return self._with_binary(pc.multiply, other) |
275 | 270 |
|
276 | 271 | def __rmul__(self, other: Any) -> Self: |
277 | | - return self * other |
| 272 | + return self._with_binary_right(pc.multiply, other) |
278 | 273 |
|
279 | 274 | def __pow__(self, other: Any) -> Self: |
280 | | - ser, other = extract_native(self, other) |
281 | | - return self._with_native(pc.power(ser, other)) |
| 275 | + return self._with_binary(pc.power, other) |
282 | 276 |
|
283 | 277 | def __rpow__(self, other: Any) -> Self: |
284 | | - ser, other = extract_native(self, other) |
285 | | - return self._with_native(pc.power(other, ser)) |
| 278 | + return self._with_binary_right(pc.power, other) |
286 | 279 |
|
287 | 280 | def __floordiv__(self, other: Any) -> Self: |
288 | | - ser, other = extract_native(self, other) |
289 | | - return self._with_native(floordiv_compat(ser, other)) |
| 281 | + return self._with_binary(floordiv_compat, other) |
290 | 282 |
|
291 | 283 | def __rfloordiv__(self, other: Any) -> Self: |
292 | | - ser, other = extract_native(self, other) |
293 | | - return self._with_native(floordiv_compat(other, ser)) |
| 284 | + return self._with_binary_right(floordiv_compat, other) |
294 | 285 |
|
295 | 286 | def __truediv__(self, other: Any) -> Self: |
296 | | - ser, other = extract_native(self, other) |
297 | | - return self._with_native(pc.divide(*cast_for_truediv(ser, other))) # type: ignore[type-var] |
| 287 | + return self._with_binary(lambda x, y: pc.divide(*cast_for_truediv(x, y)), other) |
298 | 288 |
|
299 | 289 | def __rtruediv__(self, other: Any) -> Self: |
300 | | - ser, other = extract_native(self, other) |
301 | | - return self._with_native(pc.divide(*cast_for_truediv(other, ser))) # type: ignore[type-var] |
| 290 | + return self._with_binary_right( |
| 291 | + lambda x, y: pc.divide(*cast_for_truediv(x, y)), other |
| 292 | + ) |
302 | 293 |
|
303 | 294 | def __mod__(self, other: Any) -> Self: |
| 295 | + preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True) |
304 | 296 | floor_div = (self // other).native |
305 | 297 | ser, other = extract_native(self, other) |
306 | 298 | res = pc.subtract(ser, pc.multiply(floor_div, other)) |
307 | | - return self._with_native(res) |
| 299 | + return self._with_native(res, preserve_broadcast=preserve_broadcast) |
308 | 300 |
|
309 | 301 | def __rmod__(self, other: Any) -> Self: |
| 302 | + preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True) |
310 | 303 | floor_div = (other // self).native |
311 | 304 | ser, other = extract_native(self, other) |
312 | 305 | res = pc.subtract(other, pc.multiply(floor_div, ser)) |
313 | | - return self._with_native(res) |
| 306 | + return self._with_native(res, preserve_broadcast=preserve_broadcast) |
314 | 307 |
|
315 | 308 | def __invert__(self) -> Self: |
316 | | - return self._with_native(pc.invert(self.native)) |
| 309 | + return self._with_native(pc.invert(self.native), preserve_broadcast=True) |
317 | 310 |
|
318 | 311 | @property |
319 | 312 | def _type(self) -> pa.DataType: |
|
0 commit comments