Skip to content

Commit 4b06c1c

Browse files
fyrestone and 刘宝 authored
Fix duplicate execution (#3301)
* Fix duplicate execute
* Fix

Co-authored-by: 刘宝 <[email protected]>
1 parent 4b15d0d commit 4b06c1c

File tree

6 files changed

+23
-13
lines changed

6 files changed

+23
-13
lines changed

mars/core/graph/builder/chunk.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ def __init__(
124124
self,
125125
tileable_graph: TileableGraph,
126126
tile_context: TileContext,
127-
processed_chunks: Set[ChunkType],
127+
processed_chunks: Set[str],
128128
chunk_to_fetch: Dict[ChunkType, ChunkType],
129129
add_nodes: Callable,
130130
):
@@ -301,11 +301,12 @@ def _iter(self):
301301

302302
if chunk_graph is not None:
303303
# last tiled chunks, add them to processed
304-
# so that fetch chunk can be generated
305-
processed_chunks = [
306-
c.chunk if isinstance(c, FUSE_CHUNK_TYPE) else c
304+
# so that fetch chunk can be generated.
305+
# Use chunk key as the key to make sure the copied chunk can be build to a fetch.
306+
processed_chunks = (
307+
c.chunk.key if isinstance(c, FUSE_CHUNK_TYPE) else c.key
307308
for c in chunk_graph.result_chunks
308-
]
309+
)
309310
self._processed_chunks.update(processed_chunks)
310311

311312
result_chunks = []
@@ -389,7 +390,7 @@ def __init__(
389390
self.tile_context = TileContext() if tile_context is None else tile_context
390391
self.tile_context.set_tileables(set(graph))
391392

392-
self._processed_chunks: Set[ChunkType] = set()
393+
self._processed_chunks: Set[str] = set()
393394
self._chunk_to_fetch: Dict[ChunkType, ChunkType] = dict()
394395

395396
tiler_cls = Tiler if tiler_cls is None else tiler_cls
@@ -402,7 +403,7 @@ def __init__(
402403
)
403404

404405
def _process_node(self, entity: EntityType):
405-
if entity in self._processed_chunks:
406+
if entity.key in self._processed_chunks:
406407
if entity not in self._chunk_to_fetch:
407408
# gen fetch
408409
fetch_chunk = build_fetch(entity).data
@@ -413,7 +414,7 @@ def _process_node(self, entity: EntityType):
413414
def _select_inputs(self, inputs: List[ChunkType]):
414415
new_inputs = []
415416
for inp in inputs:
416-
if inp in self._processed_chunks:
417+
if inp.key in self._processed_chunks:
417418
# gen fetch
418419
if inp not in self._chunk_to_fetch:
419420
fetch_chunk = build_fetch(inp).data
@@ -424,7 +425,7 @@ def _select_inputs(self, inputs: List[ChunkType]):
424425
return new_inputs
425426

426427
def _if_add_node(self, node: EntityType, visited: Set):
427-
return node not in visited and node not in self._processed_chunks
428+
return node not in visited and node.key not in self._processed_chunks
428429

429430
def _build(self) -> Iterable[Union[TileableGraph, ChunkGraph]]:
430431
tile_iterator = iter(self.tiler)

mars/dataframe/base/rechunk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def tile(cls, op: "DataFrameRechunk"):
157157
params["dtypes"] = pd.concat([c.dtypes for c in inp_chunks_arr[0]])
158158
if len(inp_slice_chunks) == 1:
159159
c = inp_slice_chunks[0]
160-
cc = c.op.copy().reset_key().new_chunk(c.op.inputs, kws=[params])
160+
cc = c.op.copy().new_chunk(c.op.inputs, kws=[params])
161161
out_chunks.append(cc)
162162
else:
163163
out_chunk = DataFrameConcat(

mars/dataframe/base/tests/test_base_execution.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,15 @@ def test_to_cpu_execution(setup_gpu):
9191

9292

9393
def test_rechunk_execution(setup):
94+
ns = np.random.RandomState(0)
95+
df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])
96+
97+
# test rechunk after sort
98+
mdf = DataFrame(df, chunk_size=10)
99+
result = mdf.sort_values("a0").rechunk(chunk_size=10).execute().fetch()
100+
expected = df.sort_values("a0")
101+
pd.testing.assert_frame_equal(result, expected)
102+
94103
data = pd.DataFrame(np.random.rand(8, 10))
95104
df = from_pandas_df(pd.DataFrame(data), chunk_size=3)
96105
df2 = df.rechunk((3, 4))

mars/dataframe/sort/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _tile_head(cls, op: "DataFrameSortOperand"):
8585
shape = tuple(shape)
8686
concat_params["shape"] = shape
8787
if len(to_combine_chunks) == 1:
88-
c = to_combine_chunks[0]
88+
c = to_combine_chunks[0].copy()
8989
c._index = chunk_index
9090
else:
9191
c = DataFrameConcat(

mars/dataframe/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1412,7 +1412,7 @@ def _concat_chunks(merge_chunks: List[ChunkType], output_index: int):
14121412
# concat previous chunks
14131413
if len(to_merge_chunks) == 1:
14141414
# do not generate concat op for 1 input.
1415-
c = to_merge_chunks[0]
1415+
c = to_merge_chunks[0].copy()
14161416
c._index = (
14171417
(len(n_split),) if df_or_series.ndim == 1 else (len(n_split), 0)
14181418
)

mars/services/task/supervisor/preprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
self,
3737
tileable_graph: TileableGraph,
3838
tile_context: TileContext,
39-
processed_chunks: Set[ChunkType],
39+
processed_chunks: Set[str],
4040
chunk_to_fetch: Dict[ChunkType, ChunkType],
4141
add_nodes: Callable,
4242
cancelled: asyncio.Event = None,

0 commit comments

Comments
 (0)