Skip to content

Commit 0899e4e

Browse files
String dtype: implement _get_common_dtype
1 parent 078e11f commit 0899e4e

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

pandas/core/arrays/string_.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __init__(
171171
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
172172
na_value = np.nan
173173
elif na_value is not libmissing.NA:
174-
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
174+
raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")
175175

176176
self.storage = storage
177177
self._na_value = na_value
@@ -284,6 +284,35 @@ def construct_array_type( # type: ignore[override]
284284
else:
285285
return ArrowStringArrayNumpySemantics
286286

287+
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
288+
allowed_numpy_kinds = {"S", "U"}
289+
290+
storages = set()
291+
na_values = set()
292+
293+
for dtype in dtypes:
294+
if isinstance(dtype, StringDtype):
295+
storages.add(dtype.storage)
296+
na_values.add(dtype.na_value)
297+
elif isinstance(dtype, np.dtype) and dtype.kind in allowed_numpy_kinds:
298+
continue
299+
else:
300+
return None
301+
302+
if len(storages) == 2:
303+
# if both python and pyarrow storage -> priority to pyarrow
304+
storage = "pyarrow"
305+
else:
306+
storage = next(iter(storages))
307+
308+
if len(na_values) == 2:
309+
# if both NaN and NA -> priority to NA
310+
na_value = libmissing.NA
311+
else:
312+
na_value = next(iter(na_values))
313+
314+
return StringDtype(storage=storage, na_value=na_value)
315+
287316
def __from_arrow__(
288317
self, array: pyarrow.Array | pyarrow.ChunkedArray
289318
) -> BaseStringArray:

pandas/tests/arrays/categorical/test_api.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas.compat import PY311
97

108
from pandas import (
@@ -151,7 +149,6 @@ def test_reorder_categories_raises(self, new_categories):
151149
with pytest.raises(ValueError, match=msg):
152150
cat.reorder_categories(new_categories)
153151

154-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
155152
def test_add_categories(self):
156153
cat = Categorical(["a", "b", "c", "a"], ordered=True)
157154
old = cat.copy()

0 commit comments

Comments
 (0)