Skip to content

Commit e007299

Browse files
add tests for get_loc + fix for NA variant of string dtype
1 parent 6892f83 commit e007299

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

pandas/_libs/index.pyx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,11 +559,18 @@ cdef class StringEngine(IndexEngine):
559559

560560
cdef class StringObjectEngine(ObjectEngine):
561561

562+
cdef:
563+
object na_value
564+
565+
def __init__(self, ndarray values, na_value):
566+
super().__init__(values)
567+
self.na_value = na_value
568+
562569
cdef _check_type(self, object val):
563570
if isinstance(val, str):
564571
return val
565572
elif checknull(val):
566-
return np.nan
573+
return self.na_value
567574
else:
568575
raise KeyError(val)
569576

pandas/core/indexes/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -875,9 +875,8 @@ def _engine(
875875
# error: Item "ExtensionArray" of "Union[ExtensionArray,
876876
# ndarray[Any, Any]]" has no attribute "_ndarray" [union-attr]
877877
target_values = self._data._ndarray # type: ignore[union-attr]
878-
# TODO re-enable StringEngine for string dtype
879878
elif is_string_dtype(self.dtype) and not is_object_dtype(self.dtype):
880-
return libindex.StringObjectEngine(target_values)
879+
return libindex.StringObjectEngine(target_values, self.dtype.na_value)
881880

882881
# error: Argument 1 to "ExtensionEngine" has incompatible type
883882
# "ndarray[Any, Any]"; expected "ExtensionArray"

pandas/tests/indexes/string/test_indexing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,37 @@
66
import pandas._testing as tm
77

88

9+
class TestGetLoc:
10+
def test_get_loc(self, any_string_dtype):
11+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
12+
assert index.get_loc("b") == 1
13+
14+
def test_get_loc_raises(self, any_string_dtype):
15+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
16+
with pytest.raises(KeyError, match="d"):
17+
index.get_loc("d")
18+
19+
def test_get_loc_invalid_value(self, any_string_dtype):
20+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
21+
with pytest.raises(KeyError, match="1"):
22+
index.get_loc(1)
23+
24+
def test_get_loc_non_unique(self, any_string_dtype):
25+
index = Index(["a", "b", "a"], dtype=any_string_dtype)
26+
result = index.get_loc("a")
27+
expected = np.array([True, False, True])
28+
tm.assert_numpy_array_equal(result, expected)
29+
30+
def test_get_loc_non_missing(self, any_string_dtype, nulls_fixture):
31+
index = Index(["a", "b", "c"], dtype=any_string_dtype)
32+
with pytest.raises(KeyError):
33+
index.get_loc(nulls_fixture)
34+
35+
def test_get_loc_missing(self, any_string_dtype, nulls_fixture):
36+
index = Index(["a", "b", nulls_fixture], dtype=any_string_dtype)
37+
assert index.get_loc(nulls_fixture) == 2
38+
39+
940
class TestGetIndexer:
1041
@pytest.mark.parametrize(
1142
"method,expected",

0 commit comments

Comments
 (0)