|
17 | 17 | from collections import defaultdict |
18 | 18 | from typing import Dict, Iterable, List, Tuple |
19 | 19 |
|
20 | | -from ....utils import ceildiv, lazy_import |
| 20 | +from ....utils import lazy_import |
21 | 21 |
|
22 | 22 | ray = lazy_import("ray") |
23 | 23 | parallel_it = lazy_import("ray.util.iter") |
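
Note: the `if ray:` guard added further down relies on `lazy_import` resolving to a falsy value when Ray is not installed. A minimal sketch of that pattern, assuming the real helper in `....utils` behaves roughly like this (it may differ):

```python
# Hedged sketch of a lazy-import helper; assumes ....utils.lazy_import returns
# None (or another falsy placeholder) when the requested module is unavailable.
import importlib
import importlib.util


def lazy_import_sketch(name):
    # Guard optional dependencies: call sites can simply test `if module:`.
    if importlib.util.find_spec(name) is None:
        return None
    return importlib.import_module(name)


ray = lazy_import_sketch("ray")  # None when Ray is absent
```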
@@ -71,43 +71,37 @@ def _group_chunk_refs( |
71 | 71 |     return group_to_obj_refs |
72 | 72 |
|
73 | 73 |
|
74 | | -def _create_ml_dataset(name: str, group_to_obj_refs: Dict[str, List["ray.ObjectRef"]]): |
75 | | -    record_batches = [] |
76 | | -    for rank, obj_refs in enumerate(group_to_obj_refs.values()): |
77 | | -        record_batches.append(ChunkRefBatch(shard_id=rank, obj_refs=obj_refs)) |
78 | | -    worker_cls = ray.remote(num_cpus=0)(parallel_it.ParallelIteratorWorker) |
79 | | -    actors = [worker_cls.remote(g, False) for g in record_batches] |
80 | | -    it = parallel_it.from_actors(actors, name) |
81 | | -    ds = ml_dataset.from_parallel_iter( |
82 | | -        it, need_convert=False, batch_size=0, repeated=False |
83 | | -    ) |
84 | | -    return ds |
| 74 | +def _rechunk_if_needed(df, num_shards: int = None): |
| 75 | +    try: |
| 76 | +        if num_shards: |
| 77 | +            assert isinstance(num_shards, int) and num_shards > 0 |
| 78 | +            df = df.rebalance(axis=0, num_partitions=num_shards) |
| 79 | +        df = df.rechunk({1: df.shape[1]}) |
| 80 | +        df = df.reset_index(drop=True) |
| 81 | +        return df.execute() |
| 82 | +    except Exception as e:  # pragma: no cover |
| 83 | +        raise Exception(f"rechunk failed df.shape {df.shape}") from e |
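
For context, a hedged sketch of what the simplified `_rechunk_if_needed` does to a Mars dataframe: rows are rebalanced across `num_shards` partitions and the column axis is collapsed into a single chunk, so each shard later receives complete rows. The dataframe contents and chunk sizes below are illustrative, and an active Mars session is assumed:

```python
# Hedged usage sketch; mirrors the calls made in _rechunk_if_needed above.
import pandas as pd
import mars.dataframe as md

df = md.DataFrame(pd.DataFrame({"a": range(8), "b": range(8)}), chunk_size=2)
df = df.rebalance(axis=0, num_partitions=4)   # one row-partition per shard
df = df.rechunk({1: df.shape[1]})             # each chunk spans all columns
df = df.reset_index(drop=True)                # rebuild a clean RangeIndex
df.execute()                                  # materialize the new chunking
```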
85 | 84 |
|
86 | 85 |
|
87 | | -def _rechunk_if_needed(df, num_shards: int = None): |
88 | | -    chunk_size = df.extra_params.raw_chunk_size or max(df.shape) |
89 | | -    num_rows = df.shape[0] |
90 | | -    num_columns = df.shape[1] |
91 | | -    # if chunk size not set, num_chunks_in_row = 1 |
92 | | -    # if chunk size is set more than max(df.shape), num_chunks_in_row = 1 |
93 | | -    # otherwise, num_chunks_in_row depends on ceildiv(num_rows, chunk_size) |
94 | | -    num_chunks_in_row = ceildiv(num_rows, chunk_size) |
95 | | -    naive_num_partitions = ceildiv(num_rows, num_columns) |
96 | | - |
97 | | -    need_re_execute = False |
98 | | -    # ensure each part holds all columns |
99 | | -    if chunk_size < num_columns: |
100 | | -        df = df.rebalance(axis=1, num_partitions=1) |
101 | | -        need_re_execute = True |
102 | | -    if num_shards and num_chunks_in_row < num_shards: |
103 | | -        df = df.rebalance(axis=0, num_partitions=num_shards) |
104 | | -        need_re_execute = True |
105 | | -    if not num_shards and num_chunks_in_row == 1: |
106 | | -        df = df.rebalance(axis=0, num_partitions=naive_num_partitions) |
107 | | -        need_re_execute = True |
108 | | -    if need_re_execute: |
109 | | -        df.execute() |
110 | | -    return df |
| 86 | +if ray: |
| 87 | + |
| 88 | +    class _MLDataset(ml_dataset.MLDataset): |
| 89 | +        def __init__(self, mars_dataframe, actor_sets, name: str, parent_iterators): |
| 90 | +            super().__init__(actor_sets, name, parent_iterators, 0, False) |
| 91 | +            # Hold the mars dataframe so it and its ray objects are not garbage-collected. |
| 92 | +            # TODO(mubai) Use a separate operator for rechunking and avoiding gc. |
| 93 | +            self._mars_dataframe = mars_dataframe |
| 94 | + |
| 95 | +        def __getstate__(self): |
| 96 | +            state = self.__dict__.copy() |
| 97 | +            state.pop("_mars_dataframe", None) |
| 98 | +            return state |
| 99 | + |
| 100 | +        # The default __setstate__ will update _MLDataset's __dict__; no override is needed. |
| 101 | + |
| 102 | + |
| 103 | +else: |
| 104 | +    _MLDataset = None |
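
The `__getstate__` override above exists so the heavyweight Mars dataframe is pinned on the driver (preventing the underlying Ray objects from being garbage-collected) but dropped whenever the dataset is pickled, e.g. when shipped to a worker. A small, generic illustration of that pattern, using a hypothetical `Holder` class that is not part of this module:

```python
# Generic sketch of the "pin locally, drop on pickle" pattern used by _MLDataset.
import pickle


class Holder:
    def __init__(self, payload, heavy_ref):
        self.payload = payload
        self._heavy_ref = heavy_ref  # kept alive locally, never serialized

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_heavy_ref", None)  # same trick as _MLDataset.__getstate__
        return state


h = Holder(payload=[1, 2, 3], heavy_ref=object())
restored = pickle.loads(pickle.dumps(h))
assert not hasattr(restored, "_heavy_ref") and restored.payload == [1, 2, 3]
```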
111 | 105 |
|
112 | 106 |
|
113 | 107 | def to_ray_mldataset(df, num_shards: int = None): |
@@ -139,4 +133,11 @@ def to_ray_mldataset(df, num_shards: int = None): |
139 | 133 |     group_to_obj_refs: Dict[str, List[ray.ObjectRef]] = _group_chunk_refs( |
140 | 134 |         chunk_addr_refs, num_shards |
141 | 135 |     ) |
142 | | -    return _create_ml_dataset("from_mars", group_to_obj_refs) |
| 136 | + |
| 137 | +    record_batches = [] |
| 138 | +    for rank, obj_refs in enumerate(group_to_obj_refs.values()): |
| 139 | +        record_batches.append(ChunkRefBatch(shard_id=rank, obj_refs=obj_refs)) |
| 140 | +    worker_cls = ray.remote(num_cpus=0)(parallel_it.ParallelIteratorWorker) |
| 141 | +    actors = [worker_cls.remote(g, False) for g in record_batches] |
| 142 | +    it = parallel_it.from_actors(actors, "from_mars") |
| 143 | +    return _MLDataset(df, it.actor_sets, it.name, it.parent_iterators) |
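
A hedged end-to-end sketch of how the rewritten `to_ray_mldataset` is meant to be used; it assumes Mars is running on a Ray cluster and that `to_ray_mldataset` is importable from this module, so the exact import path and session setup may differ in the repository:

```python
# Hedged usage sketch; to_ray_mldataset is the function changed in this diff.
import pandas as pd
import mars.dataframe as md

df = md.DataFrame(pd.DataFrame({"a": range(100), "b": range(100)}))
mlds = to_ray_mldataset(df, num_shards=4)
# mlds is an _MLDataset (a ray.util.data.MLDataset); it keeps `df` referenced
# on the driver, and each of the 4 shards iterates pandas batches rebuilt from
# the Mars chunk object refs.
assert mlds.num_shards() == 4
```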