Skip to content

Commit a057995

Browse files
authored
Add logic key for tileable graph (#2961)
1 parent 2966c8e commit a057995

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

mars/core/graph/entity.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ...serialization.core import buffered
2020
from ...serialization.serializables import Serializable, DictField, ListField, BoolField
2121
from ...serialization.serializables.core import SerializableSerializer
22+
from ...utils import tokenize
2223
from .core import DAG
2324

2425

@@ -57,6 +58,11 @@ def copy(self) -> "EntityGraph":
5758

5859
class TileableGraph(EntityGraph, Iterable[Tileable]):
5960
_result_tileables: List[Tileable]
61+
# logic key is a unique and deterministic key for `TileableGraph`. Across
# multiple runs the logic key will remain the same as long as the
# computational logic doesn't change. It can be used for optimizations
# when the same `execute` is run again, such as HBO.
65+
_logic_key: str
6066

6167
def __init__(self, result_tileables: List[Tileable] = None):
6268
super().__init__()
@@ -74,6 +80,19 @@ def results(self):
7480
def results(self, new_results):
7581
self._result_tileables = new_results
7682

83+
@property
def logic_key(self):
    """Deterministic key identifying this graph's computational logic.

    Computed lazily on first access: each node contributes its op's
    logic key — tokenized together with its extra params when present,
    so ops differing only in parameters (e.g. ``chunk_size``) yield
    different keys — in BFS order, and the combined token is cached on
    the instance.
    """
    cached = getattr(self, "_logic_key", None)
    if cached is not None:
        return cached

    def _node_token(node):
        # extra params participate in the key so that otherwise-equal
        # ops with different parameters do not collide
        op_key = node.op.get_logic_key()
        if node.extra_params:
            return tokenize(op_key, **node.extra_params)
        return op_key

    self._logic_key = tokenize(*(_node_token(n) for n in self.bfs()))
    return self._logic_key
95+
7796

7897
class ChunkGraph(EntityGraph, Iterable[Chunk]):
7998
_result_chunks: List[Chunk]

mars/core/graph/tests/test_graph.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import pytest
1616

17+
import numpy as np
18+
19+
from .... import dataframe as md
1720
from .... import tensor as mt
1821
from ....tests.core import flaky
1922
from ....utils import to_str
@@ -105,3 +108,90 @@ def test_to_dot():
105108

106109
dot = to_str(graph.to_dot(trunc_key=5))
107110
assert all(to_str(n.key)[:5] in dot for n in graph) is True
111+
112+
113+
def test_tileable_graph_logic_key():
    """Logic keys are stable across identical graph definitions and
    change whenever the computational logic (e.g. chunk size, op, or
    graph shape) changes."""
    # --- Tensor graphs ---
    # identical definitions -> identical keys
    x1 = mt.random.randint(10, size=(10, 8), chunk_size=4)
    x2 = mt.random.randint(10, size=(10, 8), chunk_size=5)
    tensor_graph_a = (x1 + x2).build_graph(tile=False)
    y1 = mt.random.randint(10, size=(10, 8), chunk_size=4)
    y2 = mt.random.randint(10, size=(10, 8), chunk_size=5)
    tensor_graph_b = (y1 + y2).build_graph(tile=False)
    assert tensor_graph_a.logic_key == tensor_graph_b.logic_key
    # different chunk_size -> different key; same chunk_size -> same key
    x3 = mt.random.randint(10, size=(10, 8), chunk_size=6)
    y3 = mt.random.randint(10, size=(10, 8), chunk_size=6)
    tensor_graph_c = (x1 + x3).build_graph(tile=False)
    tensor_graph_d = (x1 + y3).build_graph(tile=False)
    assert tensor_graph_a.logic_key != tensor_graph_c.logic_key
    assert tensor_graph_c.logic_key == tensor_graph_d.logic_key
    # omitting chunk_size also changes the key
    x4 = mt.random.randint(10, size=(10, 8))
    tensor_graph_e = (x1 + x4).build_graph(tile=False)
    assert tensor_graph_a.logic_key != tensor_graph_e.logic_key

    # --- Series graphs ---
    s1 = md.Series([1, 3, 5, mt.nan, 6, 8])
    s2 = md.Series(np.random.randn(1000), chunk_size=100)
    series_graph_a = (s1 + s2).build_graph(tile=False)
    r1 = md.Series([1, 3, 5, mt.nan, 6, 8])
    r2 = md.Series(np.random.randn(1000), chunk_size=100)
    series_graph_b = (r1 + r2).build_graph(tile=False)
    assert series_graph_a.logic_key == series_graph_b.logic_key
    s3 = md.Series(np.random.randn(1000), chunk_size=200)
    r3 = md.Series(np.random.randn(1000), chunk_size=200)
    series_graph_c = (s1 + s3).build_graph(tile=False)
    series_graph_d = (s1 + r3).build_graph(tile=False)
    assert series_graph_a.logic_key != series_graph_c.logic_key
    assert series_graph_c.logic_key == series_graph_d.logic_key
    s4 = md.Series(np.random.randn(1000))
    series_graph_e = (s1 + s4).build_graph(tile=False)
    assert series_graph_a.logic_key != series_graph_e.logic_key

    # --- DataFrame graphs ---
    frame1 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=5
    )
    frame2 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=4
    )
    df_graph_a = (frame1 + frame2).build_graph(tile=False)
    other1 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=5
    )
    other2 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=4
    )
    df_graph_b = (other1 + other2).build_graph(tile=False)
    assert df_graph_a.logic_key == df_graph_b.logic_key
    frame3 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=3
    )
    other3 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD"), chunk_size=3
    )
    df_graph_c = (frame1 + frame3).build_graph(tile=False)
    df_graph_d = (frame1 + other3).build_graph(tile=False)
    assert df_graph_a.logic_key != df_graph_c.logic_key
    assert df_graph_c.logic_key == df_graph_d.logic_key
    frame5 = md.DataFrame(
        np.random.randint(0, 100, size=(100_000, 4)), columns=list("ABCD")
    )
    df_graph_e = (frame1 + frame5).build_graph(tile=False)
    assert df_graph_a.logic_key != df_graph_e.logic_key
    # chunk_size differences propagate through higher-level ops as well:
    # describe, apply, concat, merge and groupby-agg each produce
    # distinct keys for differently-chunked inputs
    describe_a = frame1.describe().build_graph(tile=False)
    describe_b = frame2.describe().build_graph(tile=False)
    assert describe_a.logic_key != describe_b.logic_key
    apply_a = frame1.apply(lambda x: x.max() - x.min()).build_graph(tile=False)
    apply_b = frame2.apply(lambda x: x.max() - x.min()).build_graph(tile=False)
    assert apply_a.logic_key != apply_b.logic_key
    slices_a = [frame1[:3], frame1[3:7], frame1[7:]]
    concat_a = md.concat(slices_a).build_graph(tile=False)
    slices_b = [frame2[:3], frame2[3:7], frame2[7:]]
    concat_b = md.concat(slices_b).build_graph(tile=False)
    assert concat_a.logic_key != concat_b.logic_key
    merge_a = md.merge(frame1, frame2, on="A", how="left").build_graph(tile=False)
    merge_b = md.merge(frame1, frame3, on="A", how="left").build_graph(tile=False)
    assert merge_a.logic_key != merge_b.logic_key
    groupby_a = frame2.groupby("A").sum().build_graph(tile=False)
    groupby_b = frame3.groupby("A").sum().build_graph(tile=False)
    assert groupby_a.logic_key != groupby_b.logic_key

0 commit comments

Comments
 (0)