|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from functools import partial |
4 | 3 | import operator |
5 | 4 | import re |
6 | 5 | from typing import ( |
@@ -216,12 +215,17 @@ def dtype(self) -> StringDtype: # type: ignore[override] |
216 | 215 | return self._dtype |
217 | 216 |
|
218 | 217 | def insert(self, loc: int, item) -> ArrowStringArray: |
| 218 | + if self.dtype.na_value is np.nan and item is np.nan: |
| 219 | + item = libmissing.NA |
219 | 220 | if not isinstance(item, str) and item is not libmissing.NA: |
220 | 221 | raise TypeError("Scalar must be NA or str") |
221 | 222 | return super().insert(loc, item) |
222 | 223 |
|
223 | | - @classmethod |
224 | | - def _result_converter(cls, values, na=None): |
| 224 | + def _result_converter(self, values, na=None): |
| 225 | + if self.dtype.na_value is np.nan: |
| 226 | + if not isna(na): |
| 227 | + values = values.fill_null(bool(na)) |
| 228 | + return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
225 | 229 | return BooleanDtype().__from_arrow__(values) |
226 | 230 |
|
227 | 231 | def _maybe_convert_setitem_value(self, value): |
@@ -492,11 +496,30 @@ def _str_get_dummies(self, sep: str = "|"): |
492 | 496 | return dummies.astype(np.int64, copy=False), labels |
493 | 497 |
|
494 | 498 | def _convert_int_dtype(self, result): |
| 499 | + if self.dtype.na_value is np.nan: |
| 500 | + if isinstance(result, pa.Array): |
| 501 | + result = result.to_numpy(zero_copy_only=False) |
| 502 | + else: |
| 503 | + result = result.to_numpy() |
| 504 | + if result.dtype == np.int32: |
| 505 | + result = result.astype(np.int64) |
| 506 | + return result |
| 507 | + |
495 | 508 | return Int64Dtype().__from_arrow__(result) |
496 | 509 |
|
497 | 510 | def _reduce( |
498 | 511 | self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs |
499 | 512 | ): |
| 513 | + if self.dtype.na_value is np.nan and name in ["any", "all"]: |
| 514 | + if not skipna: |
| 515 | + nas = pc.is_null(self._pa_array) |
| 516 | + arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 517 | + else: |
| 518 | + arr = pc.not_equal(self._pa_array, "") |
| 519 | + return ArrowExtensionArray(arr)._reduce( |
| 520 | + name, skipna=skipna, keepdims=keepdims, **kwargs |
| 521 | + ) |
| 522 | + |
500 | 523 | result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs) |
501 | 524 | if name in ("argmin", "argmax") and isinstance(result, pa.Array): |
502 | 525 | return self._convert_int_dtype(result) |
@@ -527,67 +550,31 @@ def _rank( |
527 | 550 | ) |
528 | 551 | ) |
529 | 552 |
|
530 | | - |
531 | | -class ArrowStringArrayNumpySemantics(ArrowStringArray): |
532 | | - _storage = "pyarrow" |
533 | | - _na_value = np.nan |
534 | | - |
535 | | - @classmethod |
536 | | - def _result_converter(cls, values, na=None): |
537 | | - if not isna(na): |
538 | | - values = values.fill_null(bool(na)) |
539 | | - return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
540 | | - |
541 | | - def __getattribute__(self, item): |
542 | | - # ArrowStringArray and we both inherit from ArrowExtensionArray, which |
543 | | - # creates inheritance problems (Diamond inheritance) |
544 | | - if item in ArrowStringArrayMixin.__dict__ and item not in ( |
545 | | - "_pa_array", |
546 | | - "__dict__", |
547 | | - ): |
548 | | - return partial(getattr(ArrowStringArrayMixin, item), self) |
549 | | - return super().__getattribute__(item) |
550 | | - |
551 | | - def _convert_int_dtype(self, result): |
552 | | - if isinstance(result, pa.Array): |
553 | | - result = result.to_numpy(zero_copy_only=False) |
554 | | - else: |
555 | | - result = result.to_numpy() |
556 | | - if result.dtype == np.int32: |
557 | | - result = result.astype(np.int64) |
| 553 | + def value_counts(self, dropna: bool = True) -> Series: |
| 554 | + result = super().value_counts(dropna=dropna) |
| 555 | + if self.dtype.na_value is np.nan: |
| 556 | + res_values = result._values.to_numpy() |
| 557 | + return result._constructor( |
| 558 | + res_values, index=result.index, name=result.name, copy=False |
| 559 | + ) |
558 | 560 | return result |
559 | 561 |
|
560 | 562 | def _cmp_method(self, other, op): |
561 | 563 | result = super()._cmp_method(other, op) |
562 | | - if op == operator.ne: |
563 | | - return result.to_numpy(np.bool_, na_value=True) |
564 | | - else: |
565 | | - return result.to_numpy(np.bool_, na_value=False) |
566 | | - |
567 | | - def value_counts(self, dropna: bool = True) -> Series: |
568 | | - from pandas import Series |
569 | | - |
570 | | - result = super().value_counts(dropna) |
571 | | - return Series( |
572 | | - result._values.to_numpy(), index=result.index, name=result.name, copy=False |
573 | | - ) |
574 | | - |
575 | | - def _reduce( |
576 | | - self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs |
577 | | - ): |
578 | | - if name in ["any", "all"]: |
579 | | - if not skipna: |
580 | | - nas = pc.is_null(self._pa_array) |
581 | | - arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 564 | + if self.dtype.na_value is np.nan: |
| 565 | + if op == operator.ne: |
| 566 | + return result.to_numpy(np.bool_, na_value=True) |
582 | 567 | else: |
583 | | - arr = pc.not_equal(self._pa_array, "") |
584 | | - return ArrowExtensionArray(arr)._reduce( |
585 | | - name, skipna=skipna, keepdims=keepdims, **kwargs |
586 | | - ) |
587 | | - else: |
588 | | - return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs) |
| 568 | + return result.to_numpy(np.bool_, na_value=False) |
| 569 | + return result |
589 | 570 |
|
590 | | - def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics: |
591 | | - if item is np.nan: |
592 | | - item = libmissing.NA |
593 | | - return super().insert(loc, item) # type: ignore[return-value] |
| 571 | + |
| 572 | +class ArrowStringArrayNumpySemantics(ArrowStringArray): |
| 573 | + _na_value = np.nan |
| 574 | + _str_get = ArrowStringArrayMixin._str_get |
| 575 | + _str_removesuffix = ArrowStringArrayMixin._str_removesuffix |
| 576 | + _str_capitalize = ArrowStringArrayMixin._str_capitalize |
| 577 | + _str_pad = ArrowStringArrayMixin._str_pad |
| 578 | + _str_title = ArrowStringArrayMixin._str_title |
| 579 | + _str_swapcase = ArrowStringArrayMixin._str_swapcase |
| 580 | + _str_slice_replace = ArrowStringArrayMixin._str_slice_replace |
0 commit comments