Skip to content

Commit b709458

Browse files
authored
Fix tests for cudf 21.10 (#2574)
1 parent 56acab7 commit b709458

File tree

3 files changed

+65
-11
lines changed

3 files changed

+65
-11
lines changed

mars/dataframe/reduction/aggregation.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -710,7 +710,12 @@ def _wrap_df(cls, op, value, index=None):
710710
elif not isinstance(value, xdf.DataFrame):
711711
new_index = None if not op.gpu else getattr(value, "index", None)
712712
dtype = getattr(value, "dtype", None)
713-
value = xdf.DataFrame(value, columns=index, index=new_index)
713+
if xdf is pd:
714+
value = xdf.DataFrame(value, columns=index, index=new_index)
715+
else: # pragma: no cover
716+
value = xdf.DataFrame(value)
717+
value.index = new_index
718+
value.columns = index
714719
else:
715720
return value
716721

mars/serialization/cuda.py

Lines changed: 48 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414

1515
from typing import Any, List, Dict
1616

17+
import pandas as pd
18+
1719
from ..utils import lazy_import
1820
from .core import Serializer, buffered
1921

@@ -49,13 +51,57 @@ def deserialize(self, header: Dict, buffers: List, context: Dict):
4951
class CudfSerializer(Serializer):
5052
serializer_name = "cudf"
5153

54+
@staticmethod
def _get_ext_index_type(index_obj):
    """Describe a MultiIndex so it can be rebuilt after deserialization.

    Returns ``None`` when ``index_obj`` is neither a pandas nor a cudf
    ``MultiIndex``. Otherwise returns a dict with two entries:
    ``"index_type"`` (``"pandas"`` or ``"cudf"``) and ``"names"`` (the
    level names as a list).
    """
    import cudf

    # pandas is checked first, matching the order the header tag is resolved.
    if isinstance(index_obj, pd.MultiIndex):
        kind = "pandas"
    elif isinstance(index_obj, cudf.MultiIndex):
        kind = "cudf"
    else:
        return None
    return {"index_type": kind, "names": list(index_obj.names)}
70+
71+
@staticmethod
def _apply_index_type(obj, attr, header):
    """Restore a MultiIndex on ``obj.<attr>`` from a saved header.

    ``header`` is the dict produced by ``_get_ext_index_type``: its
    ``"index_type"`` selects pandas or cudf ``MultiIndex`` and its
    ``"names"`` supplies the level names. If the attribute already holds a
    MultiIndex of either flavour, the object is left untouched.
    """
    import cudf

    if header["index_type"] == "pandas":
        target_cls = pd.MultiIndex
    else:
        target_cls = cudf.MultiIndex

    current = getattr(obj, attr)
    # Only rebuild when the attribute is not already a MultiIndex.
    if not isinstance(current, (pd.MultiIndex, cudf.MultiIndex)):
        rebuilt = target_cls.from_tuples(current, names=header["names"])
        setattr(obj, attr, rebuilt)
83+
5284
def serialize(self, obj: Any, context: Dict):
53-
return obj.device_serialize()
85+
header, buffers = obj.device_serialize()
86+
if hasattr(obj, "columns"):
87+
header["_ext_columns"] = self._get_ext_index_type(obj.columns)
88+
if hasattr(obj, "index"):
89+
header["_ext_index"] = self._get_ext_index_type(obj.index)
90+
return header, buffers
5491

5592
def deserialize(self, header: Dict, buffers: List, context: Dict):
5693
from cudf.core.abc import Serializable
5794

58-
return Serializable.device_deserialize(header, buffers)
95+
col_header = header.pop("_ext_columns", None)
96+
index_header = header.pop("_ext_index", None)
97+
98+
result = Serializable.device_deserialize(header, buffers)
99+
100+
if col_header is not None:
101+
self._apply_index_type(result, "columns", col_header)
102+
if index_header is not None:
103+
self._apply_index_type(result, "index", index_header)
104+
return result
59105

60106

61107
if cupy is not None:

mars/serialization/tests/test_serial.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -168,15 +168,18 @@ def test_cupy(np_val):
168168

169169
@require_cudf
170170
def test_cudf():
171-
test_df = cudf.DataFrame(
172-
pd.DataFrame(
173-
{
174-
"a": np.random.rand(1000),
175-
"b": np.random.choice(list("abcd"), size=(1000,)),
176-
"c": np.random.randint(0, 100, size=(1000,)),
177-
}
178-
)
171+
raw_df = pd.DataFrame(
172+
{
173+
"a": np.random.rand(1000),
174+
"b": np.random.choice(list("abcd"), size=(1000,)),
175+
"c": np.random.randint(0, 100, size=(1000,)),
176+
}
179177
)
178+
test_df = cudf.DataFrame(raw_df)
179+
cudf.testing.assert_frame_equal(test_df, deserialize(*serialize(test_df)))
180+
181+
raw_df.columns = pd.MultiIndex.from_tuples([("a", "a"), ("a", "b"), ("b", "c")])
182+
test_df = cudf.DataFrame(raw_df)
180183
cudf.testing.assert_frame_equal(test_df, deserialize(*serialize(test_df)))
181184

182185

0 commit comments

Comments
 (0)