Skip to content

Commit 3b0130e

Browse files
authored
Shuffle both sides at the same time for md.merge (#3041)
1 parent 3f9fb48 commit 3b0130e

File tree

2 files changed

+132
-63
lines changed

2 files changed

+132
-63
lines changed

mars/dataframe/merge/merge.py

Lines changed: 122 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
NamedTupleField,
3636
)
3737
from ...typing import TileableType
38-
from ...utils import has_unknown_shape
38+
from ...utils import has_unknown_shape, lazy_import
3939
from ..base.bloom_filter import filter_by_bloom_filter
40-
from ..core import DataFrame, Series
40+
from ..core import DataFrame, Series, DataFrameChunk
4141
from ..operands import DataFrameOperand, DataFrameOperandMixin, DataFrameShuffleProxy
4242
from ..utils import (
4343
auto_merge_chunks,
@@ -46,6 +46,7 @@
4646
parse_index,
4747
hash_dataframe_on,
4848
infer_index_value,
49+
is_cudf,
4950
)
5051

5152
import logging
@@ -63,34 +64,29 @@
6364
BLOOM_FILTER_ON_OPTIONS = ["large", "small", "both"]
6465
DEFAULT_BLOOM_FILTER_ON = "large"
6566

67+
cudf = lazy_import("cudf")
68+
6669

6770
class DataFrameMergeAlign(MapReduceOperand, DataFrameOperandMixin):
6871
_op_type_ = OperandDef.DATAFRAME_SHUFFLE_MERGE_ALIGN
6972

70-
_index_shuffle_size = Int32Field("index_shuffle_size")
71-
_shuffle_on = AnyField("shuffle_on")
72-
73-
_input = KeyField("input")
73+
index_shuffle_size = Int32Field("index_shuffle_size")
74+
shuffle_on = AnyField("shuffle_on")
7475

75-
def __init__(self, index_shuffle_size=None, shuffle_on=None, **kw):
76-
super().__init__(
77-
_index_shuffle_size=index_shuffle_size,
78-
_shuffle_on=shuffle_on,
79-
_output_types=[OutputType.dataframe],
80-
**kw,
81-
)
76+
input = KeyField("input")
8277

83-
@property
84-
def index_shuffle_size(self):
85-
return self._index_shuffle_size
78+
def __init__(self, output_types=None, **kw):
79+
super().__init__(_output_types=output_types, **kw)
80+
if output_types is None:
81+
if self.stage == OperandStage.map:
82+
output_types = [OutputType.dataframe]
83+
elif self.stage == OperandStage.reduce:
84+
output_types = [OutputType.dataframe] * 2
85+
self._output_types = output_types
8686

8787
@property
88-
def shuffle_on(self):
89-
return self._shuffle_on
90-
91-
def _set_inputs(self, inputs):
92-
super()._set_inputs(inputs)
93-
self._input = self._inputs[0]
88+
def output_limit(self) -> int:
89+
return len(self.output_types)
9490

9591
@classmethod
9692
def execute_map(cls, ctx, op):
@@ -123,16 +119,18 @@ def execute_map(cls, ctx, op):
123119

124120
@classmethod
125121
def execute_reduce(cls, ctx, op: "DataFrameMergeAlign"):
126-
chunk = op.outputs[0]
127-
input_idx_to_df = dict(op.iter_mapper_data_with_index(ctx, skip_none=True))
128-
row_idxes = sorted({idx[0] for idx in input_idx_to_df})
129-
130-
res = []
131-
for row_idx in row_idxes:
132-
row_df = input_idx_to_df.get((row_idx, 0), None)
133-
if row_df is not None:
134-
res.append(row_df)
135-
ctx[chunk.key] = pd.concat(res, axis=0)
122+
for i, chunk in enumerate(op.outputs):
123+
input_idx_to_df = dict(
124+
op.iter_mapper_data_with_index(ctx, mapper_id=i, skip_none=True)
125+
)
126+
row_idxes = sorted({idx[0] for idx in input_idx_to_df})
127+
res = []
128+
for row_idx in row_idxes:
129+
row_df = input_idx_to_df.get((row_idx, 0), None)
130+
if row_df is not None:
131+
res.append(row_df)
132+
xdf = cudf if is_cudf(res[0]) else pd
133+
ctx[chunk.key] = xdf.concat(res, axis=0)
136134

137135
@classmethod
138136
def execute(cls, ctx, op):
@@ -213,6 +211,30 @@ def __call__(self, left, right):
213211
columns_value=parse_index(merged.columns, store_data=True),
214212
)
215213

214+
@classmethod
215+
def _gen_map_chunk(
216+
cls,
217+
chunk: DataFrameChunk,
218+
shuffle_on: Union[List, str],
219+
out_size: int,
220+
mapper_id: int = 0,
221+
):
222+
map_op = DataFrameMergeAlign(
223+
stage=OperandStage.map,
224+
shuffle_on=shuffle_on,
225+
sparse=chunk.issparse(),
226+
mapper_id=mapper_id,
227+
index_shuffle_size=out_size,
228+
)
229+
return map_op.new_chunk(
230+
[chunk],
231+
shape=(np.nan, np.nan),
232+
dtypes=chunk.dtypes,
233+
index=chunk.index,
234+
index_value=chunk.index_value,
235+
columns_value=chunk.columns_value,
236+
)
237+
216238
@classmethod
217239
def _gen_shuffle_chunks(
218240
cls,
@@ -221,24 +243,9 @@ def _gen_shuffle_chunks(
221243
df: Union[DataFrame, Series],
222244
):
223245
# gen map chunks
224-
map_chunks = []
225-
for chunk in df.chunks:
226-
map_op = DataFrameMergeAlign(
227-
stage=OperandStage.map,
228-
shuffle_on=shuffle_on,
229-
sparse=chunk.issparse(),
230-
index_shuffle_size=out_shape[0],
231-
)
232-
map_chunks.append(
233-
map_op.new_chunk(
234-
[chunk],
235-
shape=(np.nan, np.nan),
236-
dtypes=chunk.dtypes,
237-
index=chunk.index,
238-
index_value=chunk.index_value,
239-
columns_value=chunk.columns_value,
240-
)
241-
)
246+
map_chunks = [
247+
cls._gen_map_chunk(chunk, shuffle_on, out_shape[0]) for chunk in df.chunks
248+
]
242249

243250
proxy_chunk = DataFrameShuffleProxy(
244251
output_types=[OutputType.dataframe]
@@ -254,7 +261,9 @@ def _gen_shuffle_chunks(
254261
reduce_chunks = []
255262
for out_idx in itertools.product(*(range(s) for s in out_shape)):
256263
reduce_op = DataFrameMergeAlign(
257-
stage=OperandStage.reduce, sparse=proxy_chunk.issparse()
264+
stage=OperandStage.reduce,
265+
sparse=proxy_chunk.issparse(),
266+
output_types=[OutputType.dataframe],
258267
)
259268
reduce_chunks.append(
260269
reduce_op.new_chunk(
@@ -268,6 +277,65 @@ def _gen_shuffle_chunks(
268277
)
269278
return reduce_chunks
270279

280+
@classmethod
281+
def _gen_both_shuffle_chunks(
282+
cls,
283+
out_shape: Tuple,
284+
left_shuffle_on: Union[List, str],
285+
right_shuffle_on: Union[List, str],
286+
left: Union[DataFrame, Series],
287+
right: Union[DataFrame, Series],
288+
):
289+
# gen map chunks
290+
# for left dataframe, use 0 as mapper_id
291+
left_map_chunks = [
292+
cls._gen_map_chunk(chunk, left_shuffle_on, out_shape[0], mapper_id=0)
293+
for chunk in left.chunks
294+
]
295+
# for right dataframe, use 1 as mapper_id
296+
right_map_chunks = [
297+
cls._gen_map_chunk(chunk, right_shuffle_on, out_shape[0], mapper_id=1)
298+
for chunk in right.chunks
299+
]
300+
map_chunks = left_map_chunks + right_map_chunks
301+
302+
proxy_chunk = DataFrameShuffleProxy(
303+
output_types=[OutputType.dataframe]
304+
).new_chunk(
305+
map_chunks,
306+
shape=(),
307+
dtypes=left.dtypes,
308+
index_value=left.index_value,
309+
columns_value=left.columns_value,
310+
)
311+
312+
# gen reduce chunks
313+
left_reduce_chunks = []
314+
right_reduce_chunks = []
315+
for out_idx in itertools.product(*(range(s) for s in out_shape)):
316+
reduce_op = DataFrameMergeAlign(
317+
stage=OperandStage.reduce, sparse=proxy_chunk.issparse()
318+
)
319+
left_param = {
320+
"shape": (np.nan, np.nan),
321+
"dtypes": left.dtypes,
322+
"index": out_idx,
323+
"index_value": left.index_value,
324+
"columns_value": left.columns_value,
325+
}
326+
right_param = {
327+
"shape": (np.nan, np.nan),
328+
"dtypes": right.dtypes,
329+
"index": out_idx,
330+
"index_value": right.index_value,
331+
"columns_value": right.columns_value,
332+
}
333+
params = [left_param, right_param]
334+
left_reduce, right_reduce = reduce_op.new_chunks([proxy_chunk], kws=params)
335+
left_reduce_chunks.append(left_reduce)
336+
right_reduce_chunks.append(right_reduce)
337+
return left_reduce_chunks, right_reduce_chunks
338+
271339
@classmethod
272340
def _apply_bloom_filter(
273341
cls,
@@ -404,8 +472,9 @@ def _tile_shuffle(
404472
right_on = _prepare_shuffle_on(op.right_index, op.right_on, op.on)
405473

406474
# do shuffle
407-
left_chunks = cls._gen_shuffle_chunks(out_chunk_shape, left_on, left)
408-
right_chunks = cls._gen_shuffle_chunks(out_chunk_shape, right_on, right)
475+
left_chunks, right_chunks = cls._gen_both_shuffle_chunks(
476+
out_chunk_shape, left_on, right_on, left, right
477+
)
409478

410479
out_chunks = []
411480
for left_chunk, right_chunk in zip(left_chunks, right_chunks):

mars/dataframe/merge/tests/test_merge.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ def test_merge():
5555
assert left.op.stage == OperandStage.reduce
5656
assert isinstance(right.op, DataFrameMergeAlign)
5757
assert right.op.stage == OperandStage.reduce
58-
assert len(left.inputs[0].inputs) == 2
59-
assert len(right.inputs[0].inputs) == 2
60-
for lchunk in left.inputs[0].inputs:
58+
assert len(left.inputs[0].inputs) == 4
59+
assert len(right.inputs[0].inputs) == 4
60+
for lchunk in left.inputs[0].inputs[:2]:
6161
assert isinstance(lchunk.op, DataFrameMergeAlign)
6262
assert lchunk.op.stage == OperandStage.map
6363
assert lchunk.op.index_shuffle_size == 2
6464
assert lchunk.op.shuffle_on == kw.get("on", None) or kw.get(
6565
"left_on", None
6666
)
67-
for rchunk in right.inputs[0].inputs:
67+
for rchunk in right.inputs[0].inputs[2:]:
6868
assert isinstance(rchunk.op, DataFrameMergeAlign)
6969
assert rchunk.op.stage == OperandStage.map
7070
assert rchunk.op.index_shuffle_size == 2
@@ -127,8 +127,8 @@ def test_join():
127127
assert left.op.stage == OperandStage.reduce
128128
assert isinstance(right.op, DataFrameMergeAlign)
129129
assert right.op.stage == OperandStage.reduce
130-
assert len(left.inputs[0].inputs) == 2
131-
assert len(right.inputs[0].inputs) == 3
130+
assert len(left.inputs[0].inputs) == 5
131+
assert len(right.inputs[0].inputs) == 5
132132
for lchunk in left.inputs[0].inputs:
133133
assert isinstance(lchunk.op, DataFrameMergeAlign)
134134
assert lchunk.op.stage == OperandStage.map
@@ -175,14 +175,14 @@ def test_join_on():
175175
assert left.op.stage == OperandStage.reduce
176176
assert isinstance(right.op, DataFrameMergeAlign)
177177
assert right.op.stage == OperandStage.reduce
178-
assert len(left.inputs[0].inputs) == 2
179-
assert len(right.inputs[0].inputs) == 3
180-
for lchunk in left.inputs[0].inputs:
178+
assert len(left.inputs[0].inputs) == 5
179+
assert len(right.inputs[0].inputs) == 5
180+
for lchunk in left.inputs[0].inputs[:2]:
181181
assert isinstance(lchunk.op, DataFrameMergeAlign)
182182
assert lchunk.op.stage == OperandStage.map
183183
assert lchunk.op.index_shuffle_size == 3
184184
assert lchunk.op.shuffle_on == kw.get("on", None)
185-
for rchunk in right.inputs[0].inputs:
185+
for rchunk in right.inputs[0].inputs[2:]:
186186
assert isinstance(rchunk.op, DataFrameMergeAlign)
187187
assert rchunk.op.stage == OperandStage.map
188188
assert rchunk.op.index_shuffle_size == 3

0 commit comments

Comments
 (0)