Skip to content

Commit bebc442

Browse files
author
Kei
committed
Update tests to check expected inferred dtype instead of inputy dtype
1 parent abd0adf commit bebc442

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

pandas/tests/extension/base/groupby.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,18 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping):
5858
expected = pd.DataFrame({"B": uniques, "A": exp_vals})
5959
tm.assert_frame_equal(result, expected)
6060

61-
def test_groupby_agg_extension(self, data_for_grouping):
61+
def test_groupby_agg_extension(
62+
self, data_for_grouping, expected_inferred_result_dtype
63+
):
6264
# GH#38980 groupby agg on extension type fails for non-numeric types
6365
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
6466

65-
expected = df.iloc[[0, 2, 4, 7]]
67+
expected_df = pd.DataFrame(
68+
{"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping}
69+
)
70+
expected = expected_df.iloc[[0, 2, 4, 7]]
6671
expected = expected.set_index("A")
72+
expected["B"] = expected["B"].astype(expected_inferred_result_dtype)
6773

6874
result = df.groupby("A").agg({"B": "first"})
6975
tm.assert_frame_equal(result, expected)

pandas/tests/extension/test_arrow.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,25 @@ def data_for_grouping(dtype):
225225
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
226226

227227

228+
@pytest.fixture
229+
def expected_inferred_result_dtype(dtype):
230+
"""
231+
When the data pass through aggregate,
232+
the inferred data type that it will become
233+
234+
"""
235+
236+
pa_dtype = dtype.pyarrow_dtype
237+
if pa.types.is_date(pa_dtype):
238+
return "date32[day][pyarrow]"
239+
elif pa.types.is_time(pa_dtype):
240+
return "time64[us][pyarrow]"
241+
elif pa.types.is_decimal(pa_dtype):
242+
return ArrowDtype(pa.decimal128(4, 3))
243+
else:
244+
return dtype
245+
246+
228247
@pytest.fixture
229248
def data_for_sorting(data_for_grouping):
230249
"""

0 commit comments

Comments
 (0)