|
14 | 14 |
|
15 | 15 | from typing import Any, List, Dict |
16 | 16 |
|
| 17 | +import pandas as pd |
| 18 | + |
17 | 19 | from ..utils import lazy_import |
18 | 20 | from .core import Serializer, buffered |
19 | 21 |
|
@@ -49,13 +51,57 @@ def deserialize(self, header: Dict, buffers: List, context: Dict): |
49 | 51 | class CudfSerializer(Serializer): |
50 | 52 | serializer_name = "cudf" |
51 | 53 |
|
| 54 | + @staticmethod |
| 55 | + def _get_ext_index_type(index_obj): |
| 56 | + import cudf |
| 57 | + |
| 58 | + multi_index_type = None |
| 59 | + if isinstance(index_obj, pd.MultiIndex): |
| 60 | + multi_index_type = "pandas" |
| 61 | + elif isinstance(index_obj, cudf.MultiIndex): |
| 62 | + multi_index_type = "cudf" |
| 63 | + |
| 64 | + if multi_index_type is None: |
| 65 | + return None |
| 66 | + return { |
| 67 | + "index_type": multi_index_type, |
| 68 | + "names": list(index_obj.names), |
| 69 | + } |
| 70 | + |
| 71 | + @staticmethod |
| 72 | + def _apply_index_type(obj, attr, header): |
| 73 | + import cudf |
| 74 | + |
| 75 | + multi_index_cls = ( |
| 76 | + pd.MultiIndex if header["index_type"] == "pandas" else cudf.MultiIndex |
| 77 | + ) |
| 78 | + original_index = getattr(obj, attr) |
| 79 | + if isinstance(original_index, (pd.MultiIndex, cudf.MultiIndex)): |
| 80 | + return |
| 81 | + new_index = multi_index_cls.from_tuples(original_index, names=header["names"]) |
| 82 | + setattr(obj, attr, new_index) |
| 83 | + |
52 | 84 | def serialize(self, obj: Any, context: Dict): |
53 | | - return obj.device_serialize() |
| 85 | + header, buffers = obj.device_serialize() |
| 86 | + if hasattr(obj, "columns"): |
| 87 | + header["_ext_columns"] = self._get_ext_index_type(obj.columns) |
| 88 | + if hasattr(obj, "index"): |
| 89 | + header["_ext_index"] = self._get_ext_index_type(obj.index) |
| 90 | + return header, buffers |
54 | 91 |
|
55 | 92 | def deserialize(self, header: Dict, buffers: List, context: Dict): |
56 | 93 | from cudf.core.abc import Serializable |
57 | 94 |
|
58 | | - return Serializable.device_deserialize(header, buffers) |
| 95 | + col_header = header.pop("_ext_columns", None) |
| 96 | + index_header = header.pop("_ext_index", None) |
| 97 | + |
| 98 | + result = Serializable.device_deserialize(header, buffers) |
| 99 | + |
| 100 | + if col_header is not None: |
| 101 | + self._apply_index_type(result, "columns", col_header) |
| 102 | + if index_header is not None: |
| 103 | + self._apply_index_type(result, "index", index_header) |
| 104 | + return result |
59 | 105 |
|
60 | 106 |
|
61 | 107 | if cupy is not None: |
|
0 commit comments