Skip to content

Commit 2966c8e

Browse files
author
Xuye (Chris) Qin
authored
Support reporting tile progress (#2954)
1 parent 3278699 commit 2966c8e

File tree

23 files changed

+774
-455
lines changed

23 files changed

+774
-455
lines changed

mars/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,7 @@
6161
ChunkGraph,
6262
TileableGraphBuilder,
6363
ChunkGraphBuilder,
64+
TileContext,
65+
TileStatus,
6466
)
6567
from .mode import enter_mode, is_build_mode, is_eager_mode, is_kernel_mode

mars/core/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import wraps
16-
from typing import Dict
16+
from typing import Dict, Tuple, Type
1717

1818
from ..serialization.core import Placeholder, fast_id
1919
from ..serialization.serializables import Serializable, StringField
@@ -117,6 +117,15 @@ def key(self):
117117
def id(self):
118118
return self._id
119119

120+
def to_kv(
    self,
    exclude_fields: Tuple[str, ...],
    accept_value_types: Tuple[Type, ...],
) -> dict:
    """Export the currently set field values keyed by each field's tag.

    Parameters
    ----------
    exclude_fields
        Attribute names to leave out of the result.
    accept_value_types
        Only values that are instances of one of these types are kept
        (passed straight to ``isinstance``).

    Returns
    -------
    dict
        ``{field.tag: value}`` for every entry of ``self._FIELD_VALUES`` that
        is neither excluded nor of an unaccepted type.
    """
    fields = self._FIELDS
    return {
        # the key is the field's tag, not the attribute name
        fields[attr_name].tag: value
        for attr_name, value in self._FIELD_VALUES.items()
        if attr_name not in exclude_fields
        and isinstance(value, accept_value_types)
    }
128+
120129

121130
def buffered_base(func):
122131
@wraps(func)

mars/core/entity/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ def refresh_tileable_shape(tileable):
2828

2929

3030
def tile(tileable, *tileables: TileableType):
31-
from ..graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
31+
from ..graph import (
32+
TileableGraph,
33+
TileableGraphBuilder,
34+
ChunkGraphBuilder,
35+
TileContext,
36+
)
3237

3338
raw_tileables = target_tileables = [tileable] + list(tileables)
3439
target_tileables = [t.data if hasattr(t, "data") else t for t in target_tileables]
@@ -38,7 +43,7 @@ def tile(tileable, *tileables: TileableType):
3843
next(tileable_graph_builder.build())
3944

4045
# tile
41-
tile_context = dict()
46+
tile_context = TileContext()
4247
chunk_graph_builder = ChunkGraphBuilder(
4348
tileable_graph, fuse_enabled=False, tile_context=tile_context
4449
)

mars/core/graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .builder import TileableGraphBuilder, ChunkGraphBuilder
15+
from .builder import TileableGraphBuilder, ChunkGraphBuilder, TileContext, TileStatus
1616
from .core import DirectedGraph, DAG, GraphContainsCycleError
1717
from .entity import TileableGraph, ChunkGraph, EntityGraph

mars/core/graph/builder/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .chunk import ChunkGraphBuilder
15+
from .chunk import ChunkGraphBuilder, TileContext, TileStatus
1616
from .tileable import TileableGraphBuilder

mars/core/graph/builder/chunk.py

Lines changed: 110 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
import functools
1617
from typing import (
1718
Callable,
1819
Dict,
@@ -35,6 +36,7 @@
3536

3637

3738
tile_gen_type = Generator[List[ChunkType], List[ChunkType], List[TileableType]]
39+
DEFAULT_UPDATED_PROGRESS = 0.4
3840

3941

4042
@dataclasses.dataclass
@@ -44,14 +46,84 @@ class _TileableHandler:
4446
last_need_processes: List[EntityType] = None
4547

4648

49+
@dataclasses.dataclass
class _TileableTileInfo:
    # index of the tile iteration in which this record was created
    curr_iter: int
    # incremental progress for this iteration
    tile_progress: float
    # newly generated chunks by a tileable in this iteration
    generated_chunks: List[ChunkType] = dataclasses.field(default_factory=list)


class TileContext(Dict[TileableType, TileableType]):
    """Dict of tileables that additionally tracks tiling progress.

    A tileable that is present as a key is considered fully tiled
    (progress ``1.0``); partial progress of tileables that are still being
    tiled is kept separately, together with per-iteration tile records.
    """

    # tileables whose progress is tracked; ``None`` until ``set_tileables``
    _tileables: Optional[Set[TileableType]]
    # partial progress of tileables that are not yet keys of this dict
    _tileable_to_progress: Dict[TileableType, float]
    # per-tileable history of tile iterations
    _tileable_to_tile_infos: Dict[TileableType, List[_TileableTileInfo]]

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        self._tileables = None
        self._tileable_to_progress = {}
        self._tileable_to_tile_infos = {}

    def set_tileables(self, tileables: Set[TileableType]):
        """Register the tileables that overall progress is computed over."""
        self._tileables = tileables

    def __setitem__(self, key, value):
        # once a tileable is stored it counts as finished (see get_progress),
        # so its partial progress entry is no longer needed
        self._tileable_to_progress.pop(key, None)
        return super().__setitem__(key, value)

    def set_progress(self, tileable: TileableType, progress: float):
        """Record partial progress of ``tileable``; progress never decreases.

        Raises
        ------
        ValueError
            If ``progress`` is not within ``[0.0, 1.0]``.
        """
        # explicit raise instead of ``assert`` so the check survives ``-O``
        if not 0.0 <= progress <= 1.0:
            raise ValueError(
                f"progress must be between 0.0 and 1.0, got {progress!r}"
            )
        last_progress = self._tileable_to_progress.get(tileable, 0.0)
        self._tileable_to_progress[tileable] = max(progress, last_progress)

    def get_progress(self, tileable: TileableType) -> float:
        """Return ``1.0`` for a stored tileable, else its recorded progress."""
        if tileable in self:
            return 1.0
        return self._tileable_to_progress.get(tileable, 0.0)

    def get_all_progress(self) -> float:
        """Return the mean progress over the registered tileables.

        Returns ``0.0`` when no tileables have been registered yet and
        ``1.0`` when the registered set is empty (nothing left to tile),
        instead of failing on ``len(None)`` or a division by zero.
        """
        tileables = self._tileables
        if tileables is None:
            return 0.0
        if not tileables:
            return 1.0
        return sum(self.get_progress(t) for t in tileables) / len(tileables)

    def record_tileable_tile_info(
        self, tileable: TileableType, curr_iter: int, generated_chunks: List[ChunkType]
    ):
        """Append a tile record for ``tileable`` in iteration ``curr_iter``.

        The stored ``tile_progress`` is the increment since the previous
        records: current progress minus the sum of earlier increments.
        """
        infos = self._tileable_to_tile_infos.setdefault(tileable, [])
        prev_progress = sum(info.tile_progress for info in infos)
        curr_progress = self.get_progress(tileable)
        infos.append(
            _TileableTileInfo(
                curr_iter=curr_iter,
                tile_progress=curr_progress - prev_progress,
                generated_chunks=generated_chunks,
            )
        )

    def get_tileable_tile_infos(self) -> Dict[TileableType, List[_TileableTileInfo]]:
        """Return tile records for every registered tileable.

        Tileables without records map to an empty list; an empty dict is
        returned when no tileables have been registered.
        """
        tileables = self._tileables or ()
        return {t: self._tileable_to_tile_infos.get(t, []) for t in tileables}
110+
111+
112+
@dataclasses.dataclass
class TileStatus:
    # entities a tile generator yields for processing before it resumes
    entities: Optional[List[EntityType]] = None
    # progress reported together with the entities; forwarded to
    # TileContext.set_progress by the tiler
    progress: Optional[float] = None
116+
117+
47118
class Tiler:
119+
_cur_iter: int
48120
_cur_chunk_graph: Optional[ChunkGraph]
49121
_tileable_handlers: Iterable[_TileableHandler]
50122

51123
def __init__(
52124
self,
53125
tileable_graph: TileableGraph,
54-
tile_context: Dict[TileableType, TileableType],
126+
tile_context: TileContext,
55127
processed_chunks: Set[ChunkType],
56128
chunk_to_fetch: Dict[ChunkType, ChunkType],
57129
add_nodes: Callable,
@@ -60,13 +132,31 @@ def __init__(
60132
self._tile_context = tile_context
61133
self._processed_chunks = processed_chunks
62134
self._chunk_to_fetch = chunk_to_fetch
63-
self._add_nodes = add_nodes
135+
self._add_nodes = self._wrap_add_nodes(add_nodes)
136+
self._curr_iter = 0
64137
self._cur_chunk_graph = None
65138
self._tileable_handlers = (
66139
_TileableHandler(tileable, self._tile_handler(tileable))
67140
for tileable in tileable_graph.topological_iter()
68141
)
69142

143+
def _wrap_add_nodes(self, add_nodes: Callable):
    """Wrap ``add_nodes`` so every call also records the chunks it added.

    The returned callable takes one extra trailing argument, ``tileable``.
    After delegating to ``add_nodes`` it reports the chunks that newly
    appeared in the chunk graph to the tile context, tagged with the
    current iteration number.
    """

    def inner(
        chunk_graph: ChunkGraph,
        chunks: List[ChunkType],
        visited: Set[ChunkType],
        tileable: TileableType,
    ):
        # snapshot the graph before and after to find what was added
        before = set(chunk_graph)
        add_nodes(chunk_graph, chunks, visited)
        added = set(chunk_graph).difference(before)
        self._tile_context.record_tileable_tile_info(
            tileable, self._curr_iter, list(added)
        )

    return functools.wraps(add_nodes)(inner)
159+
70160
@staticmethod
71161
def _get_data(entity: EntityType):
72162
return entity.data if hasattr(entity, "data") else entity
@@ -119,6 +209,17 @@ def _tile(
119209
):
120210
try:
121211
need_process = next(tile_handler)
212+
213+
if isinstance(need_process, TileStatus):
214+
# process tile that returns progress
215+
self._tile_context.set_progress(tileable, need_process.progress)
216+
need_process = need_process.entities
217+
else:
218+
# if progress is not specified, advance by DEFAULT_UPDATED_PROGRESS (0.4) of the remaining progress
219+
progress = self._tile_context.get_progress(tileable)
220+
new_progress = progress + (1.0 - progress) * DEFAULT_UPDATED_PROGRESS
221+
self._tile_context.set_progress(tileable, new_progress)
222+
122223
chunks = []
123224
if need_process is not None:
124225
for t in need_process:
@@ -127,7 +228,7 @@ def _tile(
127228
elif isinstance(t, TILEABLE_TYPE):
128229
to_update_tileables.append(self._get_data(t))
129230
# not finished yet
130-
self._add_nodes(chunk_graph, chunks.copy(), visited)
231+
self._add_nodes(chunk_graph, chunks.copy(), visited, tileable)
131232
next_tileable_handlers.append(
132233
_TileableHandler(tileable, tile_handler, need_process)
133234
)
@@ -145,8 +246,8 @@ def _tile(
145246
if chunks is None: # pragma: no cover
146247
raise ValueError(f"tileable({out}) is still coarse after tile")
147248
chunks = [self._get_data(c) for c in chunks]
148-
self._add_nodes(chunk_graph, chunks, visited)
149249
self._tile_context[out] = tiled_tileable
250+
self._add_nodes(chunk_graph, chunks, visited, tileable)
150251

151252
def _gen_result_chunks(
152253
self,
@@ -227,6 +328,8 @@ def _iter(self):
227328
# prune unused chunks
228329
prune_chunk_graph(chunk_graph)
229330

331+
self._curr_iter += 1
332+
230333
return to_update_tileables
231334

232335
def __iter__(self):
@@ -278,12 +381,13 @@ def __init__(
278381
self,
279382
graph: TileableGraph,
280383
fuse_enabled: bool = True,
281-
tile_context: Dict[TileableType, TileableType] = None,
384+
tile_context: TileContext = None,
282385
tiler_cls: Union[Type[Tiler], Callable] = None,
283386
):
284387
super().__init__(graph)
285388
self.fuse_enabled = fuse_enabled
286-
self.tile_context = dict() if tile_context is None else tile_context
389+
self.tile_context = TileContext() if tile_context is None else tile_context
390+
self.tile_context.set_tileables(set(graph))
287391

288392
self._processed_chunks: Set[ChunkType] = set()
289393
self._chunk_to_fetch: Dict[ChunkType, ChunkType] = dict()

mars/dataframe/merge/merge.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pandas as pd
2222

2323
from ... import opcodes as OperandDef
24-
from ...core import OutputType, recursive_tile
24+
from ...core import OutputType, recursive_tile, TileStatus
2525
from ...core.context import get_context
2626
from ...core.operand import OperandStage, MapReduceOperand
2727
from ...serialization.serializables import (
@@ -609,7 +609,7 @@ def tile(cls, op: "DataFrameMerge"):
609609
auto_merge_before
610610
and len(left.chunks) + len(right.chunks) > auto_merge_threshold
611611
):
612-
yield [left, right] + left.chunks + right.chunks
612+
yield TileStatus([left, right] + left.chunks + right.chunks, progress=0.2)
613613
left = auto_merge_chunks(ctx, left)
614614
right = auto_merge_chunks(ctx, right)
615615

@@ -626,7 +626,7 @@ def tile(cls, op: "DataFrameMerge"):
626626
right_on = _prepare_shuffle_on(op.right_index, op.right_on, op.on)
627627
if op.how == "inner" and op.bloom_filter:
628628
if has_unknown_shape(left, right):
629-
yield left.chunks + right.chunks
629+
yield TileStatus(left.chunks + right.chunks, progress=0.3)
630630
small_one = right if len(left.chunks) > len(right.chunks) else left
631631
logger.debug(
632632
"Apply bloom filter for operand %s, use DataFrame %s to build bloom filter.",
@@ -637,7 +637,9 @@ def tile(cls, op: "DataFrameMerge"):
637637
*cls._apply_bloom_filter(left, right, left_on, right_on, op)
638638
)
639639
# auto merge after bloom filter
640-
yield [left, right] + left.chunks + right.chunks
640+
yield TileStatus(
641+
[left, right] + left.chunks + right.chunks, progress=0.5
642+
)
641643
left = auto_merge_chunks(ctx, left)
642644
right = auto_merge_chunks(ctx, right)
643645

@@ -660,7 +662,9 @@ def tile(cls, op: "DataFrameMerge"):
660662
):
661663
# if how=="inner", output data size will reduce greatly with high probability,
662664
# use auto_merge_chunks to combine small chunks.
663-
yield ret[0].chunks # trigger execution for chunks
665+
yield TileStatus(
666+
ret[0].chunks, progress=0.8
667+
) # trigger execution for chunks
664668
return [auto_merge_chunks(get_context(), ret[0])]
665669
else:
666670
return ret

mars/dataframe/merge/tests/test_merge_execution.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
import pytest
1718

1819
from ....core.graph.builder.utils import build_graph
1920
from ...datasource.dataframe import from_pandas
@@ -597,7 +598,8 @@ def test_merge_with_bloom_filter(setup):
597598
)
598599

599600

600-
def test_merge_on_duplicate_columns(setup):
601+
@pytest.mark.parametrize("auto_merge", ["none", "both", "before", "after"])
602+
def test_merge_on_duplicate_columns(setup, auto_merge):
601603
raw1 = pd.DataFrame(
602604
[["foo", 1, "bar"], ["bar", 2, "foo"], ["baz", 3, "foo"]],
603605
columns=["lkey", "value", "value"],
@@ -611,7 +613,7 @@ def test_merge_on_duplicate_columns(setup):
611613
df1 = from_pandas(raw1, chunk_size=2)
612614
df2 = from_pandas(raw2, chunk_size=3)
613615

614-
r = df1.merge(df2, left_on="lkey", right_on="rkey", auto_merge="none")
616+
r = df1.merge(df2, left_on="lkey", right_on="rkey", auto_merge=auto_merge)
615617
result = r.execute().fetch()
616618
expected = raw1.merge(raw2, left_on="lkey", right_on="rkey")
617619
pd.testing.assert_frame_equal(expected, result)

mars/optimization/logical/chunk/tests/test_column_pruning.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
import pytest
2020

2121
from ..... import dataframe as md
22-
from .....core import enter_mode, TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
22+
from .....core import (
23+
enter_mode,
24+
TileableGraph,
25+
TileableGraphBuilder,
26+
ChunkGraphBuilder,
27+
TileContext,
28+
)
2329
from .. import optimize
2430

2531

@@ -47,7 +53,7 @@ def test_groupby_read_csv(gen_data1):
4753
df2 = df1[["a", "b"]]
4854
graph = TileableGraph([df2.data])
4955
next(TileableGraphBuilder(graph).build())
50-
context = dict()
56+
context = TileContext()
5157
chunk_graph_builder = ChunkGraphBuilder(
5258
graph, fuse_enabled=False, tile_context=context
5359
)

mars/optimization/logical/chunk/tests/test_head.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
import pytest
2020

2121
from ..... import dataframe as md
22-
from .....core import enter_mode, TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
22+
from .....core import (
23+
enter_mode,
24+
TileableGraph,
25+
TileableGraphBuilder,
26+
ChunkGraphBuilder,
27+
TileContext,
28+
)
2329
from .. import optimize
2430

2531

@@ -47,7 +53,7 @@ def test_read_csv_head(gen_data1):
4753
df2 = df1.head(5)
4854
graph = TileableGraph([df2.data])
4955
next(TileableGraphBuilder(graph).build())
50-
context = dict()
56+
context = TileContext()
5157
chunk_graph_builder = ChunkGraphBuilder(
5258
graph, fuse_enabled=False, tile_context=context
5359
)

0 commit comments

Comments
 (0)