Skip to content

Commit 2f1bc37

Browse files
rename storage option and add na_value keyword
1 parent 4fb94bb commit 2f1bc37

File tree

1 file changed

+53
-24
lines changed

1 file changed

+53
-24
lines changed

pandas/core/arrays/string_.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import numpy as np
1111

12-
from pandas._config import get_option
12+
from pandas._config import (
13+
get_option,
14+
using_pyarrow_string_dtype,
15+
)
1316

1417
from pandas._libs import (
1518
lib,
@@ -83,6 +86,7 @@ class StringDtype(StorageExtensionDtype):
8386
----------
8487
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
8588
If not given, the value of ``pd.options.mode.string_storage``.
89+
na_value :
8690
8791
Attributes
8892
----------
@@ -113,30 +117,49 @@ class StringDtype(StorageExtensionDtype):
113117
# follows NumPy semantics, which uses nan.
114118
@property
115119
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
116-
if self.storage == "pyarrow_numpy":
117-
return np.nan
118-
else:
119-
return libmissing.NA
120+
return self._na_value
120121

121122
_metadata = ("storage",)
122123

123-
def __init__(self, storage=None) -> None:
124-
if storage is None:
125-
infer_string = get_option("future.infer_string")
126-
if infer_string:
127-
storage = "pyarrow_numpy"
124+
def __init__(self, storage=None, na_value=None) -> None:
125+
if not (
126+
na_value is None or (isinstance(na_value, float) and np.isnan(na_value))
127+
):
128+
raise ValueError(
129+
"'na_value' must be the default value or pd.NA, got {na_value}"
130+
)
131+
132+
# infer defaults
133+
if storage is None and na_value is None:
134+
if using_pyarrow_string_dtype():
135+
storage = "pyarrow"
136+
na_value = np.nan
128137
else:
129138
storage = get_option("mode.string_storage")
130-
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
139+
na_value = libmissing.NA
140+
elif storage is None:
141+
# in this case na_value is NaN
142+
storage = get_option("mode.string_storage")
143+
elif na_value is None:
144+
na_value = np.nan if using_pyarrow_string_dtype() else libmissing.NA
145+
if na_value is not libmissing.NA and storage == "python":
146+
raise NotImplementedError(
147+
"'python' mode for na_value of NaN not yet implemented"
148+
)
149+
150+
if storage == "pyarrow_numpy":
151+
# TODO raise a deprecation warning
152+
storage = "pyarrow"
153+
if storage not in {"python", "pyarrow"}:
131154
raise ValueError(
132-
f"Storage must be 'python', 'pyarrow' or 'pyarrow_numpy'. "
133-
f"Got {storage} instead."
155+
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
134156
)
135-
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under10p1:
157+
if storage == "pyarrow" and pa_version_under10p1:
136158
raise ImportError(
137159
"pyarrow>=10.0.1 is required for PyArrow backed StringArray."
138160
)
139161
self.storage = storage
162+
self._na_value = na_value
140163

141164
@property
142165
def type(self) -> type[str]:
@@ -176,11 +199,14 @@ def construct_from_string(cls, string) -> Self:
176199
)
177200
if string == "string":
178201
return cls()
202+
elif string == "String":
203+
return cls(na_value=np.nan)
179204
elif string == "string[python]":
180-
return cls(storage="python")
205+
return cls(storage="python", na_value=np.nan)
181206
elif string == "string[pyarrow]":
182-
return cls(storage="pyarrow")
207+
return cls(storage="pyarrow", na_value=np.nan)
183208
elif string == "string[pyarrow_numpy]":
209+
# TODO deprecate
184210
return cls(storage="pyarrow_numpy")
185211
else:
186212
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
@@ -205,10 +231,10 @@ def construct_array_type( # type: ignore[override]
205231

206232
if self.storage == "python":
207233
return StringArray
208-
elif self.storage == "pyarrow":
209-
return ArrowStringArray
210-
else:
234+
elif self.storage == "pyarrow" and self._na_value is libmissing.NA:
211235
return ArrowStringArrayNumpySemantics
236+
else:
237+
return ArrowStringArray
212238

213239
def __from_arrow__(
214240
self, array: pyarrow.Array | pyarrow.ChunkedArray
@@ -217,13 +243,16 @@ def __from_arrow__(
217243
Construct StringArray from pyarrow Array/ChunkedArray.
218244
"""
219245
if self.storage == "pyarrow":
220-
from pandas.core.arrays.string_arrow import ArrowStringArray
246+
if self._na_value is libmissing.NA:
247+
from pandas.core.arrays.string_arrow import (
248+
ArrowStringArrayNumpySemantics,
249+
)
221250

222-
return ArrowStringArray(array)
223-
elif self.storage == "pyarrow_numpy":
224-
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
251+
return ArrowStringArrayNumpySemantics(array)
252+
else:
253+
from pandas.core.arrays.string_arrow import ArrowStringArray
225254

226-
return ArrowStringArrayNumpySemantics(array)
255+
return ArrowStringArray(array)
227256
else:
228257
import pyarrow
229258

0 commit comments

Comments
 (0)