|
| 1 | +import itertools |
1 | 2 | import operator |
2 | 3 |
|
3 | 4 | import pyarrow as pa |
|
30 | 31 | ) |
31 | 32 | from pandas.tests.extension.base.printing import BasePrintingTests |
32 | 33 | from pandas.tests.extension.base.reduce import BaseReduceTests |
| 34 | +from pandas.tests.extension.base.reshaping import BaseReshapingTests |
33 | 35 |
|
34 | 36 | # TODO(wayd): This is copied from string tests - is it required here? |
35 | 37 | # @pytest.fixture(params=[True, False]) |
@@ -83,7 +85,7 @@ class TestListArray( |
83 | 85 | BaseUnaryOpsTests, |
84 | 86 | BasePrintingTests, |
85 | 87 | BaseReduceTests, |
86 | | - # BaseReshapingTests, |
| 88 | + BaseReshapingTests, |
87 | 89 | # BaseSetitemTests, |
88 | 90 | Dim2CompatTests, |
89 | 91 | ): |
@@ -159,6 +161,73 @@ def test_compare_array(self, data, comparison_op): |
159 | 161 | def test_invert(self, data): |
160 | 162 | pytest.skip("ListArray does not implement invert") |
161 | 163 |
|
| 164 | + def test_merge_on_extension_array(self, data): |
| 165 | + pytest.skip("ListArray cannot be factorized") |
| 166 | + |
| 167 | + def test_merge_on_extension_array_duplicates(self, data): |
| 168 | + pytest.skip("ListArray cannot be factorized") |
| 169 | + |
| 170 | + @pytest.mark.parametrize( |
| 171 | + "index", |
| 172 | + [ |
| 173 | + # Two levels, uniform. |
| 174 | + pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"]), |
| 175 | + # non-uniform |
| 176 | + pd.MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("B", "b")]), |
| 177 | + # three levels, non-uniform |
| 178 | + pd.MultiIndex.from_product([("A", "B"), ("a", "b", "c"), (0, 1, 2)]), |
| 179 | + pd.MultiIndex.from_tuples( |
| 180 | + [ |
| 181 | + ("A", "a", 1), |
| 182 | + ("A", "b", 0), |
| 183 | + ("A", "a", 0), |
| 184 | + ("B", "a", 0), |
| 185 | + ("B", "c", 1), |
| 186 | + ] |
| 187 | + ), |
| 188 | + ], |
| 189 | + ) |
| 190 | + @pytest.mark.parametrize("obj", ["series", "frame"]) |
| 191 | + def test_unstack(self, data, index, obj): |
| 192 | + # TODO: the base class test casts everything to object |
| 193 | + # If you remove the object casts, these tests pass... |
| 194 | + # Check if still needed in base class |
| 195 | + data = data[: len(index)] |
| 196 | + if obj == "series": |
| 197 | + ser = pd.Series(data, index=index) |
| 198 | + else: |
| 199 | + ser = pd.DataFrame({"A": data, "B": data}, index=index) |
| 200 | + |
| 201 | + n = index.nlevels |
| 202 | + levels = list(range(n)) |
| 203 | + # [0, 1, 2] |
| 204 | + # [(0,), (1,), (2,), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] |
| 205 | + combinations = itertools.chain.from_iterable( |
| 206 | + itertools.permutations(levels, i) for i in range(1, n) |
| 207 | + ) |
| 208 | + |
| 209 | + for level in combinations: |
| 210 | + result = ser.unstack(level=level) |
| 211 | + assert all( |
| 212 | + isinstance(result[col].array, type(data)) for col in result.columns |
| 213 | + ) |
| 214 | + |
| 215 | + if obj == "series": |
| 216 | + # We should get the same result with to_frame+unstack+droplevel |
| 217 | + df = ser.to_frame() |
| 218 | + |
| 219 | + alt = df.unstack(level=level).droplevel(0, axis=1) |
| 220 | + tm.assert_frame_equal(result, alt) |
| 221 | + |
| 222 | + # obj_ser = ser.astype(object) |
| 223 | + |
| 224 | + expected = ser.unstack(level=level, fill_value=data.dtype.na_value) |
| 225 | + # if obj == "series": |
| 226 | + # assert (expected.dtypes == object).all() |
| 227 | + |
| 228 | + # result = result.astype(object) |
| 229 | + tm.assert_frame_equal(result, expected) |
| 230 | + |
162 | 231 |
|
163 | 232 | def test_to_csv(data): |
164 | 233 | # https://github.com/pandas-dev/pandas/issues/28840 |
|
0 commit comments