Skip to content

Commit 283eda9

Browse files
committed
Rework
1 parent f758eb1 commit 283eda9

File tree

3 files changed

+32
-24
lines changed

3 files changed

+32
-24
lines changed

pandas/core/groupby/ops.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
ensure_uint64,
4545
is_1d_only_ea_dtype,
4646
)
47-
from pandas.core.dtypes.dtypes import ArrowDtype
4847
from pandas.core.dtypes.missing import (
4948
isna,
5049
maybe_fill,
@@ -956,19 +955,22 @@ def agg_series(
956955
-------
957956
np.ndarray or ExtensionArray
958957
"""
959-
960958
result = self._aggregate_series_pure_python(obj, func)
961959
npvalues = lib.maybe_convert_objects(result, try_float=False)
962960

963961
if isinstance(obj._values, ArrowExtensionArray):
964-
out = maybe_cast_pointwise_result(
965-
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
966-
)
967-
import pyarrow as pa
962+
from pandas.core.dtypes.common import is_string_dtype
968963

969-
if isinstance(out.dtype, ArrowDtype) and pa.types.is_struct(
970-
out.dtype.pyarrow_dtype
971-
):
964+
if not is_string_dtype(obj.dtype) or is_string_dtype(npvalues):
965+
out = maybe_cast_pointwise_result(
966+
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
967+
)
968+
969+
# if isinstance(out.dtype, ArrowDtype) and pa.types.is_struct(
970+
# out.dtype.pyarrow_dtype
971+
# ):
972+
# out = npvalues
973+
else:
972974
out = npvalues
973975

974976
elif not isinstance(obj._values, np.ndarray):

pandas/tests/groupby/aggregate/test_aggregate.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
from pandas.errors import SpecificationError
13+
import pandas.util._test_decorators as td
1314

1415
from pandas.core.dtypes.common import is_integer_dtype
1516

@@ -23,6 +24,7 @@
2324
to_datetime,
2425
)
2526
import pandas._testing as tm
27+
from pandas.arrays import ArrowExtensionArray
2628
from pandas.core.groupby.grouper import Grouping
2729

2830

@@ -1812,16 +1814,18 @@ def test_groupby_aggregation_func_list_multi_index_duplicate_columns():
18121814
@pytest.mark.parametrize(
18131815
"input_dtype, output_dtype",
18141816
[
1817+
# With NumPy arrays, the results from the UDF would be e.g. np.float32 scalars
1818+
# which we can therefore preserve. However with PyArrow arrays, the results are
1819+
# Python scalars so we have no information about size or uint vs int.
18151820
("float[pyarrow]", "double[pyarrow]"),
18161821
("int64[pyarrow]", "int64[pyarrow]"),
18171822
("uint64[pyarrow]", "int64[pyarrow]"),
18181823
("bool[pyarrow]", "bool[pyarrow]"),
18191824
],
18201825
)
18211826
def test_agg_lambda_pyarrow_dtype_conversion(input_dtype, output_dtype):
1822-
# GH#53030
1823-
# test numpy dtype conversion back to pyarrow dtype
1824-
# complexes, floats, ints, uints, object
1827+
# GH#59601
1828+
# Test PyArrow dtype conversion back to PyArrow dtype
18251829
df = DataFrame(
18261830
{
18271831
"A": ["c1", "c2", "c3", "c1", "c2", "c3"],
@@ -1839,7 +1843,7 @@ def test_agg_lambda_pyarrow_dtype_conversion(input_dtype, output_dtype):
18391843

18401844

18411845
def test_agg_lambda_complex128_dtype_conversion():
1842-
# GH#53030
1846+
# GH#59601
18431847
df = DataFrame(
18441848
{"A": ["c1", "c2", "c3"], "B": pd.array([100, 200, 255], "int64[pyarrow]")}
18451849
)
@@ -1877,8 +1881,11 @@ def test_agg_lambda_numpy_uint64_to_pyarrow_dtype_conversion():
18771881
tm.assert_frame_equal(result, expected)
18781882

18791883

1884+
@td.skip_if_no("pyarrow")
18801885
def test_agg_lambda_pyarrow_struct_to_object_dtype_conversion():
18811886
# GH#53030
1887+
import pyarrow as pa
1888+
18821889
df = DataFrame(
18831890
{
18841891
"A": ["c1", "c2", "c3"],
@@ -1888,8 +1895,10 @@ def test_agg_lambda_pyarrow_struct_to_object_dtype_conversion():
18881895
gb = df.groupby("A")
18891896
result = gb.agg(lambda x: {"number": 1})
18901897

1898+
arr = pa.array([{"number": 1}, {"number": 1}, {"number": 1}])
18911899
expected = DataFrame(
1892-
{"B": pd.array([{"number": 1}, {"number": 1}, {"number": 1}], dtype="object")},
1900+
{"B": ArrowExtensionArray(arr)},
18931901
index=Index(["c1", "c2", "c3"], name="A"),
18941902
)
1903+
18951904
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/test_groupby.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
)
2727
import pandas._testing as tm
2828
from pandas.core.arrays import BooleanArray
29-
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
3029
import pandas.core.common as com
3130

3231
pytestmark = pytest.mark.filterwarnings("ignore:Mean of empty slice:RuntimeWarning")
@@ -2435,30 +2434,28 @@ def test_rolling_wrong_param_min_period():
24352434

24362435
def test_by_column_values_with_same_starting_value(any_string_dtype):
24372436
# GH29635
2437+
dtype = any_string_dtype
24382438
df = DataFrame(
24392439
{
24402440
"Name": ["Thomas", "Thomas", "Thomas John"],
24412441
"Credit": [1200, 1300, 900],
2442-
"Mood": Series(["sad", "happy", "happy"], dtype=any_string_dtype),
2442+
"Mood": Series(["sad", "happy", "happy"], dtype=dtype),
24432443
}
24442444
)
24452445
aggregate_details = {"Mood": Series.mode, "Credit": "sum"}
24462446

24472447
result = df.groupby(["Name"]).agg(aggregate_details)
2448-
expected_result = DataFrame(
2448+
expected = DataFrame(
24492449
{
24502450
"Mood": [["happy", "sad"], "happy"],
24512451
"Credit": [2500, 900],
24522452
"Name": ["Thomas", "Thomas John"],
24532453
},
24542454
).set_index("Name")
2455-
if dtype == "string[pyarrow_numpy]":
2456-
import pyarrow as pa
2457-
2458-
mood_values = ArrowStringArrayNumpySemantics(pa.array(["happy", "sad"]))
2459-
expected_result["Mood"] = [mood_values, "happy"]
2460-
expected_result["Mood"] = expected_result["Mood"].astype(dtype)
2461-
tm.assert_frame_equal(result, expected_result)
2455+
if getattr(dtype, "storage", None) == "pyarrow":
2456+
mood_values = pd.array(["happy", "sad"], dtype=dtype)
2457+
expected["Mood"] = [mood_values, "happy"]
2458+
tm.assert_frame_equal(result, expected)
24622459

24632460

24642461
def test_groupby_none_in_first_mi_level():

0 commit comments

Comments
 (0)