Skip to content

Commit 289cd20

Browse files
String dtype: allow string dtype for non-raw apply with numba engine
1 parent 2419343 commit 289cd20

File tree

3 files changed

+6
-11
lines changed

3 files changed

+6
-11
lines changed

pandas/core/_numba/extensions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@
5353
@contextmanager
5454
def set_numba_data(index: Index):
5555
numba_data = index._data
56-
if numba_data.dtype == object:
56+
if numba_data.dtype in (object, "string"):
57+
numba_data = np.asarray(numba_data)
5758
if not lib.is_string_array(numba_data):
5859
raise ValueError(
5960
"The numba engine only supports using string or numeric column names"

pandas/core/apply.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,12 +1172,12 @@ def apply_with_numba(self) -> dict[int, Any]:
11721172
from pandas.core._numba.extensions import set_numba_data
11731173

11741174
index = self.obj.index
1175-
if index.dtype == "string":
1176-
index = index.astype(object)
1175+
# if index.dtype == "string":
1176+
# index = index.astype(object)
11771177

11781178
columns = self.obj.columns
1179-
if columns.dtype == "string":
1180-
columns = columns.astype(object)
1179+
# if columns.dtype == "string":
1180+
# columns = columns.astype(object)
11811181

11821182
# Convert from numba dict to regular dict
11831183
# Our isinstance checks in the df constructor don't pass for numbas typed dict

pandas/tests/apply/test_frame_apply.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from pandas._config import using_string_dtype
88

9-
from pandas.compat import HAS_PYARROW
10-
119
from pandas.core.dtypes.dtypes import CategoricalDtype
1210

1311
import pandas as pd
@@ -65,7 +63,6 @@ def test_apply(float_frame, engine, request):
6563
assert result.index is float_frame.index
6664

6765

68-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
6966
@pytest.mark.parametrize("axis", [0, 1])
7067
@pytest.mark.parametrize("raw", [True, False])
7168
@pytest.mark.parametrize("nopython", [True, False])
@@ -1247,9 +1244,6 @@ def test_agg_multiple_mixed():
12471244
tm.assert_frame_equal(result, expected)
12481245

12491246

1250-
@pytest.mark.xfail(
1251-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
1252-
)
12531247
def test_agg_multiple_mixed_raises():
12541248
# GH 20909
12551249
mdf = DataFrame(

0 commit comments

Comments
 (0)