Skip to content

Commit a9650bb

Browse files
Fix StringDtype equality to take the missing-value sentinel (NaN vs. NA) into account
1 parent 8587297 commit a9650bb

File tree

3 files changed

+56
-8
lines changed

3 files changed

+56
-8
lines changed

pandas/core/arrays/string_.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,13 @@ class StringDtype(StorageExtensionDtype):
120120
def na_value(self) -> libmissing.NAType | float: # type: ignore[override]
121121
return self._na_value
122122

123-
_metadata = ("storage",)
123+
_metadata = ("storage", "na_value")
124124

125125
def __init__(
126126
self,
127127
storage: str | None = None,
128128
na_value: libmissing.NAType | float = libmissing.NA,
129129
) -> None:
130-
if not (
131-
na_value is libmissing.NA
132-
or (isinstance(na_value, float) and np.isnan(na_value))
133-
):
134-
raise ValueError("'na_value' must be np.nan or pd.NA, got {na_value}")
135-
136130
# infer defaults
137131
if storage is None:
138132
if using_string_dtype():
@@ -145,6 +139,7 @@ def __init__(
145139
storage = "pyarrow"
146140
na_value = np.nan
147141

142+
# validate options
148143
if storage not in {"python", "pyarrow"}:
149144
raise ValueError(
150145
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
@@ -153,9 +148,35 @@ def __init__(
153148
raise ImportError(
154149
"pyarrow>=10.0.1 is required for PyArrow backed StringArray."
155150
)
151+
152+
if isinstance(na_value, float) and np.isnan(na_value):
153+
# when passed a NaN value, always set to np.nan to ensure we use
154+
# a consistent NaN value (and we can use `dtype.na_value is np.nan`)
155+
na_value = np.nan
156+
elif na_value is not libmissing.NA:
157+
raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")
158+
156159
self.storage = storage
157160
self._na_value = na_value
158161

162+
def __eq__(self, other: object) -> bool:
    """
    Check whether ``other`` denotes the same string dtype as ``self``.

    Parameters
    ----------
    other : object
        Either a dtype instance or a string alias. A string equal to
        ``self.name`` matches immediately; any other string is first
        converted with ``construct_from_string``.

    Returns
    -------
    bool
        True if ``other`` is (or parses to) an instance of this dtype class
        with the same ``storage`` and the identical ``na_value`` object,
        False otherwise. Never raises for an unusable string alias.
    """
    # we need to override the base class __eq__ because na_value (NA or NaN)
    # cannot be checked with normal `==`
    if isinstance(other, str):
        if other == self.name:
            return True
        try:
            other = self.construct_from_string(other)
        except (TypeError, ImportError):
            # TypeError: the string was rejected as an alias for this dtype.
            # ImportError: presumably propagated from __init__, which raises
            # it for pyarrow storage when pyarrow is not available. An
            # equality check should answer False in that case, not raise.
            return False
    if isinstance(other, type(self)):
        # Identity comparison is intentional: __init__ normalises every NaN
        # to the np.nan object, and NaN never compares equal with `==`.
        return self.storage == other.storage and self.na_value is other.na_value
    return False
175+
176+
def __hash__(self) -> int:
177+
# need to override __hash__ as well because of overriding __eq__: Python sets
# __hash__ to None on a class that defines __eq__ without defining __hash__,
# which would make instances unhashable. The base-class hash is reused as-is.
178+
return super().__hash__()
179+
159180
@property
160181
def type(self) -> type[str]:
161182
# always the built-in `str` class, independent of storage and na_value
return str

pandas/tests/arrays/string_/test_string.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,27 @@ def cls(dtype):
3232
return dtype.construct_array_type()
3333

3434

35+
def test_dtype_equality():
36+
# two of the dtypes below use pyarrow storage, so skip without pyarrow
pytest.importorskip("pyarrow")
37+
38+
# three variants under test: python storage, pyarrow storage (both without
# an explicit na_value), and pyarrow storage with the NaN sentinel
dtype1 = pd.StringDtype("python")
39+
dtype2 = pd.StringDtype("pyarrow")
40+
dtype3 = pd.StringDtype("pyarrow", na_value=np.nan)
41+
42+
# omitting na_value is equivalent to passing pd.NA explicitly; a different
# storage or a different sentinel makes the dtypes unequal
assert dtype1 == pd.StringDtype("python", na_value=pd.NA)
43+
assert dtype1 != dtype2
44+
assert dtype1 != dtype3
45+
46+
# same checks from the pyarrow/NA side (inequality with operands swapped)
assert dtype2 == pd.StringDtype("pyarrow", na_value=pd.NA)
47+
assert dtype2 != dtype1
48+
assert dtype2 != dtype3
49+
50+
# any float NaN must match the np.nan variant, including a NaN object that
# is not np.nan itself (float("nan"))
assert dtype3 == pd.StringDtype("pyarrow", na_value=np.nan)
51+
assert dtype3 == pd.StringDtype("pyarrow", na_value=float("nan"))
52+
assert dtype3 != dtype1
53+
assert dtype3 != dtype2
54+
55+
3556
def test_repr(dtype):
3657
df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
3758
if dtype.na_value is np.nan:

pandas/tests/extension/test_string.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,15 @@ def data_for_grouping(dtype, chunked):
9696

9797
class TestStringArray(base.ExtensionTests):
9898
def test_eq_with_str(self, dtype):
99-
assert dtype == f"string[{dtype.storage}]"
10099
# run the shared base-class string-equality checks first
super().test_eq_with_str(dtype)
101100

101+
if dtype.na_value is pd.NA:
102+
# only the NA-variant supports the parametrized string alias
# "string[<storage>]"
103+
assert dtype == f"string[{dtype.storage}]"
104+
elif dtype.storage == "pyarrow":
105+
# TODO(infer_string) deprecate this
# the NaN variant with pyarrow storage is matched by the
# "string[pyarrow_numpy]" alias; no parametrized alias is asserted for
# the NaN variant with python storage
106+
assert dtype == "string[pyarrow_numpy]"
107+
102108
def test_is_not_string_type(self, dtype):
103109
# Different from BaseDtypeTests.test_is_not_string_type
104110
# because StringDtype is a string type

0 commit comments

Comments
 (0)