Skip to content

Commit 60a8eee

Browse files
committed
TST(string dtype): Make str.decode return str dtype
1 parent a81d52f commit 60a8eee

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

pandas/core/strings/accessor.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212

1313
import numpy as np
1414

15+
from pandas._config import get_option
16+
1517
from pandas._libs import lib
1618
from pandas._typing import (
1719
AlignJoin,
@@ -399,7 +401,9 @@ def cons_row(x):
399401
# This is a mess.
400402
_dtype: DtypeObj | str | None = dtype
401403
vdtype = getattr(result, "dtype", None)
402-
if self._is_string:
404+
if _dtype is not None:
405+
pass
406+
elif self._is_string:
403407
if is_bool_dtype(vdtype):
404408
_dtype = result.dtype
405409
elif returns_string:
@@ -2140,9 +2144,9 @@ def decode(self, encoding, errors: str = "strict"):
21402144
decoder = codecs.getdecoder(encoding)
21412145
f = lambda x: decoder(x, errors)[0]
21422146
arr = self._data.array
2143-
# assert isinstance(arr, (StringArray,))
21442147
result = arr._str_map(f)
2145-
return self._wrap_result(result)
2148+
dtype = "str" if get_option("future.infer_string") else None
2149+
return self._wrap_result(result, dtype=dtype)
21462150

21472151
@forbid_nonstring_types(["bytes"])
21482152
def encode(self, encoding, errors: str = "strict"):

pandas/tests/io/sas/test_sas7bdat.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,6 @@
77
import numpy as np
88
import pytest
99

10-
from pandas._config import using_string_dtype
11-
1210
from pandas.compat._constants import (
1311
IS64,
1412
WASM,
@@ -20,10 +18,6 @@
2018

2119
from pandas.io.sas.sas7bdat import SAS7BDATReader
2220

23-
pytestmark = pytest.mark.xfail(
24-
using_string_dtype(), reason="TODO(infer_string)", strict=False
25-
)
26-
2721

2822
@pytest.fixture
2923
def dirpath(datapath):

pandas/tests/strings/test_strings.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -566,7 +566,7 @@ def test_string_slice_out_of_bounds(any_string_dtype):
566566
def test_encode_decode(any_string_dtype):
567567
ser = Series(["a", "b", "a\xe4"], dtype=any_string_dtype).str.encode("utf-8")
568568
result = ser.str.decode("utf-8")
569-
expected = ser.map(lambda x: x.decode("utf-8")).astype(object)
569+
expected = Series(["a", "b", "a\xe4"], dtype="str")
570570
tm.assert_series_equal(result, expected)
571571

572572

0 commit comments

Comments
 (0)