Skip to content

Commit 1bf8a69

Browse files
add specific tests
1 parent 0899e4e commit 1bf8a69

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

pandas/core/arrays/string_.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,14 @@ def construct_array_type( # type: ignore[override]
285285
return ArrowStringArrayNumpySemantics
286286

287287
def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
288-
allowed_numpy_kinds = {"S", "U"}
289-
290288
storages = set()
291289
na_values = set()
292290

293291
for dtype in dtypes:
294292
if isinstance(dtype, StringDtype):
295293
storages.add(dtype.storage)
296294
na_values.add(dtype.na_value)
297-
elif isinstance(dtype, np.dtype) and dtype.kind in allowed_numpy_kinds:
295+
elif isinstance(dtype, np.dtype) and dtype.kind == "U":
298296
continue
299297
else:
300298
return None
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas.compat import HAS_PYARROW
5+
6+
from pandas.core.dtypes.cast import find_common_type
7+
8+
import pandas as pd
9+
import pandas._testing as tm
10+
11+
12+
@pytest.mark.parametrize(
13+
"to_concat_dtypes, result_dtype",
14+
[
15+
# same types
16+
([("pyarrow", pd.NA), ("pyarrow", pd.NA)], ("pyarrow", pd.NA)),
17+
([("pyarrow", np.nan), ("pyarrow", np.nan)], ("pyarrow", np.nan)),
18+
([("python", pd.NA), ("python", pd.NA)], ("python", pd.NA)),
19+
([("python", np.nan), ("python", np.nan)], ("python", np.nan)),
20+
# pyarrow preference
21+
([("pyarrow", pd.NA), ("python", pd.NA)], ("pyarrow", pd.NA)),
22+
# NA preference
23+
([("python", pd.NA), ("python", np.nan)], ("python", pd.NA)),
24+
],
25+
)
26+
def test_concat_series(request, to_concat_dtypes, result_dtype):
27+
if any(storage == "pyarrow" for storage, _ in to_concat_dtypes) and not HAS_PYARROW:
28+
pytest.skip("Could not import 'pyarrow'")
29+
30+
ser_list = [
31+
pd.Series(["a", "b", None], dtype=pd.StringDtype(storage, na_value))
32+
for storage, na_value in to_concat_dtypes
33+
]
34+
35+
result = pd.concat(ser_list, ignore_index=True)
36+
expected = pd.Series(
37+
["a", "b", None, "a", "b", None], dtype=pd.StringDtype(*result_dtype)
38+
)
39+
tm.assert_series_equal(result, expected)
40+
41+
# order doesn't matter for result
42+
result = pd.concat(ser_list[::1], ignore_index=True)
43+
tm.assert_series_equal(result, expected)
44+
45+
46+
def test_concat_with_object(string_dtype_arguments):
47+
# _get_common_dtype cannot inspect values, so object dtype with strings still
48+
# results in object dtype
49+
result = pd.concat(
50+
[
51+
pd.Series(["a", "b", None], dtype=pd.StringDtype(*string_dtype_arguments)),
52+
pd.Series(["a", "b", None], dtype=object),
53+
]
54+
)
55+
assert result.dtype == np.dtype("object")
56+
57+
58+
def test_concat_with_numpy(string_dtype_arguments):
59+
# common type with a numpy string dtype always preserves the pandas string dtype
60+
dtype = pd.StringDtype(*string_dtype_arguments)
61+
assert find_common_type([dtype, np.dtype("U")]) == dtype
62+
assert find_common_type([np.dtype("U"), dtype]) == dtype
63+
assert find_common_type([dtype, np.dtype("U10")]) == dtype
64+
assert find_common_type([np.dtype("U10"), dtype]) == dtype
65+
66+
# with any other numpy dtype -> object
67+
assert find_common_type([dtype, np.dtype("S")]) == np.dtype("object")
68+
assert find_common_type([dtype, np.dtype("int64")]) == np.dtype("object")

0 commit comments

Comments
 (0)