Skip to content

Commit 63a7fc5

Browse files
String dtype: implement object-dtype based StringArray variant with NumPy semantics
1 parent 6320c8b commit 63a7fc5

File tree

13 files changed

+220
-77
lines changed

13 files changed

+220
-77
lines changed

pandas/_libs/lib.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2692,7 +2692,7 @@ def maybe_convert_objects(ndarray[object] objects,
26922692
if using_pyarrow_string_dtype() and is_string_array(objects, skipna=True):
26932693
from pandas.core.arrays.string_ import StringDtype
26942694

2695-
dtype = StringDtype(storage="pyarrow_numpy")
2695+
dtype = StringDtype()
26962696
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
26972697

26982698
elif convert_to_nullable_dtype and is_string_array(objects, skipna=True):

pandas/compat/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pandas.compat.compressors
2727
from pandas.compat.numpy import is_numpy_dev
2828
from pandas.compat.pyarrow import (
29+
HAS_PYARROW,
2930
pa_version_under10p1,
3031
pa_version_under11p0,
3132
pa_version_under13p0,
@@ -189,6 +190,7 @@ def get_bz2_file() -> type[pandas.compat.compressors.BZ2File]:
189190
"pa_version_under14p0",
190191
"pa_version_under14p1",
191192
"pa_version_under16p0",
193+
"HAS_PYARROW",
192194
"IS64",
193195
"ISMUSL",
194196
"PY310",

pandas/compat/pyarrow.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
pa_version_under14p1 = _palv < Version("14.0.1")
1717
pa_version_under15p0 = _palv < Version("15.0.0")
1818
pa_version_under16p0 = _palv < Version("16.0.0")
19+
HAS_PYARROW = True
1920
except ImportError:
2021
pa_version_under10p1 = True
2122
pa_version_under11p0 = True
@@ -25,3 +26,4 @@
2526
pa_version_under14p1 = True
2627
pa_version_under15p0 = True
2728
pa_version_under16p0 = True
29+
HAS_PYARROW = False

pandas/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,7 @@ def nullable_string_dtype(request):
12921292
@pytest.fixture(
12931293
params=[
12941294
"python",
1295+
"python_numpy",
12951296
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")),
12961297
pytest.param("pyarrow_numpy", marks=td.skip_if_no("pyarrow")),
12971298
]
@@ -1353,6 +1354,7 @@ def object_dtype(request):
13531354
params=[
13541355
"object",
13551356
"string[python]",
1357+
"string[python_numpy]",
13561358
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
13571359
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
13581360
]

pandas/core/arrays/string_.py

Lines changed: 149 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
11
from __future__ import annotations
22

3+
import operator
34
from typing import (
45
TYPE_CHECKING,
6+
Any,
57
ClassVar,
68
Literal,
79
cast,
810
)
911

1012
import numpy as np
1113

12-
from pandas._config import get_option
14+
from pandas._config import (
15+
get_option,
16+
using_pyarrow_string_dtype,
17+
)
1318

1419
from pandas._libs import (
1520
lib,
1621
missing as libmissing,
1722
)
1823
from pandas._libs.arrays import NDArrayBacked
1924
from pandas._libs.lib import ensure_string_array
20-
from pandas.compat import pa_version_under10p1
25+
from pandas.compat import (
26+
HAS_PYARROW,
27+
pa_version_under10p1,
28+
)
2129
from pandas.compat.numpy import function as nv
2230
from pandas.util._decorators import doc
2331

@@ -81,7 +89,7 @@ class StringDtype(StorageExtensionDtype):
8189
8290
Parameters
8391
----------
84-
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
92+
storage : {"python", "pyarrow", "python_numpy", "pyarrow_numpy"}, optional
8593
If not given, the value of ``pd.options.mode.string_storage``.
8694
8795
Attributes
@@ -113,7 +121,7 @@ class StringDtype(StorageExtensionDtype):
113121
# follows NumPy semantics, which uses nan.
114122
@property
115123
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
116-
if self.storage == "pyarrow_numpy":
124+
if self.storage in ("pyarrow_numpy", "python_numpy"):
117125
return np.nan
118126
else:
119127
return libmissing.NA
@@ -122,15 +130,17 @@ def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
122130

123131
def __init__(self, storage=None) -> None:
124132
if storage is None:
125-
infer_string = get_option("future.infer_string")
126-
if infer_string:
127-
storage = "pyarrow_numpy"
133+
if using_pyarrow_string_dtype():
134+
if HAS_PYARROW:
135+
storage = "pyarrow_numpy"
136+
else:
137+
storage = "python_numpy"
128138
else:
129139
storage = get_option("mode.string_storage")
130-
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
140+
if storage not in {"python", "pyarrow", "python_numpy", "pyarrow_numpy"}:
131141
raise ValueError(
132-
f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. "
133-
f"Got {storage} instead."
142+
"Storage must be 'python', 'pyarrow', 'python_numpy' or 'pyarrow_numpy'"
143+
f". Got {storage} instead."
134144
)
135145
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1:
136146
raise ImportError(
@@ -178,6 +188,8 @@ def construct_from_string(cls, string) -> Self:
178188
return cls()
179189
elif string == "string[python]":
180190
return cls(storage="python")
191+
elif string == "string[python_numpy]":
192+
return cls(storage="python_numpy")
181193
elif string == "string[pyarrow]":
182194
return cls(storage="pyarrow")
183195
elif string == "string[pyarrow_numpy]":
@@ -207,6 +219,8 @@ def construct_array_type( # type: ignore[override]
207219
return StringArray
208220
elif self.storage == "pyarrow":
209221
return ArrowStringArray
222+
elif self.storage == "python_numpy":
223+
return StringArrayNumpySemantics
210224
else:
211225
return ArrowStringArrayNumpySemantics
212226

@@ -238,7 +252,7 @@ def __from_arrow__(
238252
# convert chunk by chunk to numpy and concatenate then, to avoid
239253
# overflow for large string data when concatenating the pyarrow arrays
240254
arr = arr.to_numpy(zero_copy_only=False)
241-
arr = ensure_string_array(arr, na_value=libmissing.NA)
255+
arr = ensure_string_array(arr, na_value=self.na_value)
242256
results.append(arr)
243257

244258
if len(chunks) == 0:
@@ -248,11 +262,7 @@ def __from_arrow__(
248262

249263
# Bypass validation inside StringArray constructor, see GH#47781
250264
new_string_array = StringArray.__new__(StringArray)
251-
NDArrayBacked.__init__(
252-
new_string_array,
253-
arr,
254-
StringDtype(storage="python"),
255-
)
265+
NDArrayBacked.__init__(new_string_array, arr, self)
256266
return new_string_array
257267

258268

@@ -360,14 +370,15 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
360370

361371
# undo the NumpyExtensionArray hack
362372
_typ = "extension"
373+
_storage = "python"
363374

364375
def __init__(self, values, copy: bool = False) -> None:
365376
values = extract_array(values)
366377

367378
super().__init__(values, copy=copy)
368379
if not isinstance(values, type(self)):
369380
self._validate()
370-
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage="python"))
381+
NDArrayBacked.__init__(self, self._ndarray, StringDtype(storage=self._storage))
371382

372383
def _validate(self) -> None:
373384
"""Validate that we only store NA or strings."""
@@ -385,22 +396,41 @@ def _validate(self) -> None:
385396
else:
386397
lib.convert_nans_to_NA(self._ndarray)
387398

399+
def _validate_scalar(self, value):
400+
# used by NDArrayBackedExtensionIndex.insert
401+
if isna(value):
402+
return self.dtype.na_value
403+
elif not isinstance(value, str):
404+
raise TypeError(
405+
f"Cannot set non-string value '{value}' into a string array."
406+
)
407+
return value
408+
388409
@classmethod
389410
def _from_sequence(
390411
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False
391412
) -> Self:
392413
if dtype and not (isinstance(dtype, str) and dtype == "string"):
393414
dtype = pandas_dtype(dtype)
394-
assert isinstance(dtype, StringDtype) and dtype.storage == "python"
415+
assert isinstance(dtype, StringDtype) and dtype.storage in (
416+
"python",
417+
"python_numpy",
418+
)
419+
else:
420+
if get_option("future.infer_string"):
421+
dtype = StringDtype(storage="python_numpy")
422+
else:
423+
dtype = StringDtype(storage="python")
395424

396425
from pandas.core.arrays.masked import BaseMaskedArray
397426

427+
na_value = dtype.na_value
398428
if isinstance(scalars, BaseMaskedArray):
399429
# avoid costly conversion to object dtype
400430
na_values = scalars._mask
401431
result = scalars._data
402432
result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
403-
result[na_values] = libmissing.NA
433+
result[na_values] = na_value
404434

405435
else:
406436
if lib.is_pyarrow_array(scalars):
@@ -409,12 +439,12 @@ def _from_sequence(
409439
# zero_copy_only to True which caused problems see GH#52076
410440
scalars = np.array(scalars)
411441
# convert non-na-likes to str, and nan-likes to StringDtype().na_value
412-
result = lib.ensure_string_array(scalars, na_value=libmissing.NA, copy=copy)
442+
result = lib.ensure_string_array(scalars, na_value=na_value, copy=copy)
413443

414444
# Manually creating new array avoids the validation step in the __init__, so is
415445
# faster. Refactor need for validation?
416446
new_string_array = cls.__new__(cls)
417-
NDArrayBacked.__init__(new_string_array, result, StringDtype(storage="python"))
447+
NDArrayBacked.__init__(new_string_array, result, dtype)
418448

419449
return new_string_array
420450

@@ -464,7 +494,7 @@ def __setitem__(self, key, value) -> None:
464494
# validate new items
465495
if scalar_value:
466496
if isna(value):
467-
value = libmissing.NA
497+
value = self.dtype.na_value
468498
elif not isinstance(value, str):
469499
raise TypeError(
470500
f"Cannot set non-string value '{value}' into a StringArray."
@@ -478,7 +508,7 @@ def __setitem__(self, key, value) -> None:
478508
mask = isna(value)
479509
if mask.any():
480510
value = value.copy()
481-
value[isna(value)] = libmissing.NA
511+
value[isna(value)] = self.dtype.na_value
482512

483513
super().__setitem__(key, value)
484514

@@ -591,9 +621,9 @@ def _cmp_method(self, other, op):
591621

592622
if op.__name__ in ops.ARITHMETIC_BINOPS:
593623
result = np.empty_like(self._ndarray, dtype="object")
594-
result[mask] = libmissing.NA
624+
result[mask] = self.dtype.na_value
595625
result[valid] = op(self._ndarray[valid], other)
596-
return StringArray(result)
626+
return self._from_backing_data(result)
597627
else:
598628
# logical
599629
result = np.zeros(len(self._ndarray), dtype="bool")
@@ -662,3 +692,97 @@ def _str_map(
662692
# or .findall returns a list).
663693
# -> We don't know the result type. E.g. `.get` can return anything.
664694
return lib.map_infer_mask(arr, f, mask.view("uint8"))
695+
696+
697+
class StringArrayNumpySemantics(StringArray):
698+
_storage = "python_numpy"
699+
700+
@classmethod
701+
def _from_sequence(
702+
cls, scalars, *, dtype: Dtype | None = None, copy: bool = False
703+
) -> Self:
704+
if dtype is None:
705+
dtype = StringDtype(storage="python_numpy")
706+
return super()._from_sequence(scalars, dtype=dtype, copy=copy)
707+
708+
def _from_backing_data(self, arr: np.ndarray) -> NumpyExtensionArray:
709+
# need to overrde NumpyExtensionArray._from_backing_data to ensure
710+
# we always preserve the dtype
711+
return NDArrayBacked._from_backing_data(self, arr)
712+
713+
def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
714+
# the masked_reductions use pd.NA
715+
if result is libmissing.NA:
716+
return np.nan
717+
return super()._wrap_reduction_result(axis, result)
718+
719+
def _cmp_method(self, other, op):
720+
result = super()._cmp_method(other, op)
721+
if op == operator.ne:
722+
return result.to_numpy(np.bool_, na_value=True)
723+
else:
724+
return result.to_numpy(np.bool_, na_value=False)
725+
726+
def value_counts(self, dropna: bool = True) -> Series:
727+
from pandas.core.algorithms import value_counts_internal as value_counts
728+
729+
result = value_counts(self._ndarray, sort=False, dropna=dropna)
730+
result.index = result.index.astype(self.dtype)
731+
return result
732+
733+
# ------------------------------------------------------------------------
734+
# String methods interface
735+
_str_na_value = np.nan
736+
737+
def _str_map(
738+
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
739+
):
740+
if dtype is None:
741+
dtype = self.dtype
742+
if na_value is None:
743+
na_value = self.dtype.na_value
744+
745+
mask = isna(self)
746+
arr = np.asarray(self)
747+
convert = convert and not np.all(mask)
748+
749+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
750+
# if is_integer_dtype(dtype):
751+
# na_value = np.nan
752+
# else:
753+
# na_value = False
754+
try:
755+
result = lib.map_infer_mask(
756+
arr,
757+
f,
758+
mask.view("uint8"),
759+
convert=False,
760+
na_value=na_value,
761+
dtype=np.dtype(cast(type, dtype)),
762+
)
763+
return result
764+
765+
except ValueError:
766+
result = lib.map_infer_mask(
767+
arr,
768+
f,
769+
mask.view("uint8"),
770+
convert=False,
771+
na_value=na_value,
772+
)
773+
if convert and result.dtype == object:
774+
result = lib.maybe_convert_objects(result)
775+
return result
776+
777+
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
778+
# i.e. StringDtype
779+
result = lib.map_infer_mask(
780+
arr, f, mask.view("uint8"), convert=False, na_value=na_value
781+
)
782+
return type(self)(result)
783+
else:
784+
# This is when the result type is object. We reach this when
785+
# -> We know the result type is truly object (e.g. .encode returns bytes
786+
# or .findall returns a list).
787+
# -> We don't know the result type. E.g. `.get` can return anything.
788+
return lib.map_infer_mask(arr, f, mask.view("uint8"))

pandas/core/config_init.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,9 @@ def is_terminal() -> bool:
460460
"string_storage",
461461
"python",
462462
string_storage_doc,
463-
validator=is_one_of_factory(["python", "pyarrow", "pyarrow_numpy"]),
463+
validator=is_one_of_factory(
464+
["python", "pyarrow", "python_numpy", "pyarrow_numpy"]
465+
),
464466
)
465467

466468

@@ -858,7 +860,7 @@ def register_converter_cb(key: str) -> None:
858860
with cf.config_prefix("future"):
859861
cf.register_option(
860862
"infer_string",
861-
False,
863+
True if os.environ.get("PANDAS_FUTURE_INFER_STRING", "0") == "1" else False,
862864
"Whether to infer sequence of str objects as pyarrow string "
863865
"dtype, which will be the default in pandas 3.0 "
864866
"(at which point this option will be deprecated).",

0 commit comments

Comments
 (0)