Skip to content

Commit e19455d

Browse files
committed
ENH: Add dtype argument to str.decode
1 parent 6bcd303 commit e19455d

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

doc/source/whatsnew/v2.3.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Other enhancements
3939
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
4040
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
4141
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
42+
- The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`???`)
4243

4344
.. ---------------------------------------------------------------------------
4445
.. _whatsnew_230.notable_bug_fixes:

pandas/core/strings/accessor.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
is_numeric_dtype,
3535
is_object_dtype,
3636
is_re,
37+
is_string_dtype,
3738
)
3839
from pandas.core.dtypes.dtypes import (
3940
ArrowDtype,
@@ -2102,7 +2103,7 @@ def slice_replace(self, start=None, stop=None, repl=None):
21022103
result = self._data.array._str_slice_replace(start, stop, repl)
21032104
return self._wrap_result(result)
21042105

2105-
def decode(self, encoding, errors: str = "strict"):
2106+
def decode(self, encoding, errors: str = "strict", dtype: str | DtypeObj = None):
21062107
"""
21072108
Decode character string in the Series/Index using indicated encoding.
21082109
@@ -2116,6 +2117,10 @@ def decode(self, encoding, errors: str = "strict"):
21162117
errors : str, optional
21172118
Specifies the error handling scheme.
21182119
Possible values are those supported by :meth:`bytes.decode`.
2120+
dtype : str or dtype, optional
2121+
The dtype of the result. When not ``None``, must be either a string or
2122+
object dtype. When ``None``, the dtype of the result is determined by
2123+
``pd.options.future.infer_string``.
21192124
21202125
Returns
21212126
-------
@@ -2137,6 +2142,12 @@ def decode(self, encoding, errors: str = "strict"):
21372142
2 ()
21382143
dtype: object
21392144
"""
2145+
if (
2146+
dtype is not None
2147+
and not is_string_dtype(dtype)
2148+
and not is_object_dtype(dtype)
2149+
):
2150+
raise ValueError(f"dtype must be string or object, got {dtype=}")
21402151
# TODO: Add a similar _bytes interface.
21412152
if encoding in _cpython_optimized_decoders:
21422153
# CPython optimized implementation
@@ -2146,7 +2157,8 @@ def decode(self, encoding, errors: str = "strict"):
21462157
f = lambda x: decoder(x, errors)[0]
21472158
arr = self._data.array
21482159
result = arr._str_map(f)
2149-
dtype = "str" if get_option("future.infer_string") else None
2160+
if dtype is None:
2161+
dtype = "str" if get_option("future.infer_string") else None
21502162
return self._wrap_result(result, dtype=dtype)
21512163

21522164
@forbid_nonstring_types(["bytes"])

pandas/tests/strings/test_strings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,27 @@ def test_decode_errors_kwarg():
601601
tm.assert_series_equal(result, expected)
602602

603603

604+
def test_decode_string_dtype(string_dtype):
605+
ser = Series([b"a", b"b"])
606+
result = ser.str.decode("utf-8", dtype=string_dtype)
607+
expected = Series(["a", "b"], dtype=string_dtype)
608+
tm.assert_series_equal(result, expected)
609+
610+
611+
def test_decode_object_dtype(object_dtype):
612+
ser = Series([b"a", rb"\ud800"])
613+
result = ser.str.decode("utf-8", dtype=object_dtype)
614+
expected = Series(["a", r"\ud800"], dtype=object_dtype)
615+
tm.assert_series_equal(result, expected)
616+
617+
618+
def test_decode_bad_dtype():
619+
ser = Series([b"a", b"b"])
620+
msg = "dtype must be string or object, got dtype='int64'"
621+
with pytest.raises(ValueError, match=msg):
622+
ser.str.decode("utf-8", dtype="int64")
623+
624+
604625
@pytest.mark.parametrize(
605626
"form, expected",
606627
[

0 commit comments

Comments
 (0)