Skip to content

Commit 12525c0

Browse files
committed
Add replace_subgraph with tests
1 parent 3418861 commit 12525c0

File tree

4 files changed

+315
-1
lines changed

4 files changed

+315
-1
lines changed

mars/core/entity/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs):
4242
def op(self):
4343
return self._op
4444

45+
@property
46+
def outputs(self):
47+
return self._op.outputs
48+
4549
@property
4650
def inputs(self):
4751
return self.op.inputs

mars/optimization/logical/core.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import functools
15+
import itertools
1516
import weakref
1617
from abc import ABC, abstractmethod
1718
from collections import defaultdict
1819
from dataclasses import dataclass
1920
from enum import Enum
2021
from typing import Dict, List, Optional, Type, Set
2122

22-
from ...core import OperandType, EntityType, enter_mode
23+
from ...core import OperandType, EntityType, enter_mode, Entity
2324
from ...core.graph import EntityGraph
2425
from ...utils import implements
2526

@@ -130,6 +131,77 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
130131
for succ in successors:
131132
self._graph.add_edge(new_node, succ)
132133

134+
def _replace_subgraph(
135+
self,
136+
graph: Optional[EntityGraph],
137+
removed_nodes: Optional[Set[EntityType]],
138+
new_results: Optional[List[Entity]] = None,
139+
):
140+
"""
141+
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
142+
It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still
143+
existed in self._graph) the nodes and edges of the input graph.
144+
145+
Parameters
146+
----------
147+
graph : EntityGraph, optional
148+
The input graph. If it's none, no new node and edge will be added.
149+
removed_nodes : Set[EntityType], optional
150+
The nodes to be removed. All the edges connected with them are removed as well.
151+
new_results : List[EntityType], optional, default None
152+
The updated results of the graph. If it's None, then the results will not be updated.
153+
154+
Raises
155+
------
156+
ReplaceSubgraphError
157+
If the input key of the removed node's successor can't be found in the subgraph.
158+
Or some of the nodes of the subgraph are in removed ones.
159+
"""
160+
infected_successors = set()
161+
162+
output_to_node = dict()
163+
removed_nodes = removed_nodes or set()
164+
if graph is not None:
165+
# Add the output key -> node of the subgraph
166+
for node in graph.iter_nodes():
167+
if node in removed_nodes:
168+
raise ReplaceSubgraphError(f"The node {node} is in the removed set")
169+
for output in node.outputs:
170+
output_to_node[output.key] = node
171+
172+
for node in removed_nodes:
173+
for infected_successor in self._graph.iter_successors(node):
174+
if infected_successor not in removed_nodes:
175+
infected_successors.add(infected_successor)
176+
# Check whether infected successors' inputs are in subgraph
177+
for infected_successor in infected_successors:
178+
for inp in infected_successor.inputs:
179+
if inp.key not in output_to_node:
180+
raise ReplaceSubgraphError(
181+
f"The output {inp} of node {infected_successor} is missing in the subgraph"
182+
)
183+
for node in removed_nodes:
184+
self._graph.remove_node(node)
185+
186+
if graph is None:
187+
return
188+
189+
# Add the output key -> node of the original graph
190+
for node in self._graph.iter_nodes():
191+
for output in node.outputs:
192+
output_to_node[output.key] = node
193+
194+
for node in graph.iter_nodes():
195+
self._graph.add_node(node)
196+
197+
for node in itertools.chain(graph.iter_nodes(), infected_successors):
198+
for inp in node.inputs:
199+
pred_node = output_to_node[inp.key]
200+
self._graph.add_edge(pred_node, node)
201+
202+
if new_results is not None:
203+
self._graph.results = new_results.copy()
204+
133205
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
134206
pred_original = self._records.get_original_entity(predecessor, predecessor)
135207
if predecessor not in self._preds_to_remove:
@@ -283,3 +355,7 @@ def optimize(cls, graph: EntityGraph) -> OptimizationRecords:
283355
graph.results = new_results
284356

285357
return records
358+
359+
360+
class ReplaceSubgraphError(Exception):
361+
pass
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import itertools
15+
import pytest
16+
17+
import mars.tensor as mt
18+
from ..core import OptimizationRule, ReplaceSubgraphError
19+
from .... import dataframe as md
20+
21+
22+
class _MockRule(OptimizationRule):
23+
def apply(self) -> bool:
24+
pass
25+
26+
def replace_subgraph(self, graph, removed_nodes, new_results=None):
27+
self._replace_subgraph(graph, removed_nodes, new_results)
28+
29+
30+
def test_replace_tileable_subgraph():
31+
"""
32+
Original Graph:
33+
s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5
34+
| ^
35+
| |
36+
V |
37+
v3 ------|
38+
^
39+
|
40+
s2 ---> c2 ---> v2
41+
42+
Target Graph:
43+
s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5
44+
^
45+
|
46+
s2 ---> c2 ---> v2
47+
48+
The nodes [v3, v4, v6] will be removed.
49+
Subgraph only contains [v7, v8]
50+
"""
51+
s1 = mt.random.randint(0, 100, size=(5, 4))
52+
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
53+
s2 = mt.random.randint(0, 100, size=(5, 4))
54+
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
55+
v3 = v1.add(v2)
56+
v4 = v3.add(v1)
57+
s5 = mt.random.randint(0, 100, size=(5, 4))
58+
v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4)
59+
v6 = v5.sub(v4)
60+
g1 = v6.build_graph()
61+
v7 = v1.sub(v2)
62+
v8 = v7.add(v5)
63+
g2 = v8.build_graph()
64+
65+
# Here we use a trick way to construct the subgraph for test only
66+
key_to_node = dict()
67+
for node in g2.iter_nodes():
68+
key_to_node[node.key] = node
69+
for key, node in key_to_node.items():
70+
if key != v7.key and key != v8.key:
71+
g2.remove_node(node)
72+
r = _MockRule(g1, None, None)
73+
for node in g1.iter_nodes():
74+
key_to_node[node.key] = node
75+
76+
c1 = g1.successors(key_to_node[s1.key])[0]
77+
c2 = g1.successors(key_to_node[s2.key])[0]
78+
c5 = g1.successors(key_to_node[s5.key])[0]
79+
80+
expected_results = [v8.outputs[0]]
81+
r.replace_subgraph(
82+
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results
83+
)
84+
assert g1.results == expected_results
85+
86+
expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
87+
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}
88+
89+
expected_edges = {
90+
s1: [c1],
91+
c1: [v1],
92+
v1: [v7],
93+
s2: [c2],
94+
c2: [v2],
95+
v2: [v7],
96+
s5: [c5],
97+
c5: [v5],
98+
v5: [v8],
99+
v7: [v8],
100+
v8: [],
101+
}
102+
for pred, successors in expected_edges.items():
103+
pred_node = key_to_node[pred.key]
104+
assert g1.count_successors(pred_node) == len(successors)
105+
for successor in successors:
106+
assert g1.has_successor(pred_node, key_to_node[successor.key])
107+
108+
109+
def test_replace_null_subgraph():
110+
"""
111+
Original Graph:
112+
s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
113+
114+
Target Graph:
115+
c1 ---> v1 ---> v3 <--- v2 <--- c2
116+
117+
The nodes [s1, s2] will be removed.
118+
Subgraph is None
119+
"""
120+
s1 = mt.random.randint(0, 100, size=(10, 4))
121+
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
122+
s2 = mt.random.randint(0, 100, size=(10, 4))
123+
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
124+
v3 = v1.add(v2)
125+
g1 = v3.build_graph()
126+
key_to_node = {node.key: node for node in g1.iter_nodes()}
127+
c1 = g1.successors(key_to_node[s1.key])[0]
128+
c2 = g1.successors(key_to_node[s2.key])[0]
129+
r = _MockRule(g1, None, None)
130+
expected_results = [v3.outputs[0]]
131+
# delete c5 s5 will fail
132+
with pytest.raises(ReplaceSubgraphError) as e:
133+
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
134+
assert g1.results == expected_results
135+
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
136+
expected_edges = {
137+
s1: [c1],
138+
c1: [v1],
139+
v1: [v3],
140+
s2: [c2],
141+
c2: [v2],
142+
v2: [v3],
143+
v3: [],
144+
}
145+
for pred, successors in expected_edges.items():
146+
pred_node = key_to_node[pred.key]
147+
assert g1.count_successors(pred_node) == len(successors)
148+
for successor in successors:
149+
assert g1.has_successor(pred_node, key_to_node[successor.key])
150+
151+
c1.inputs.clear()
152+
c2.inputs.clear()
153+
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
154+
assert g1.results == expected_results
155+
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
156+
expected_edges = {
157+
c1: [v1],
158+
v1: [v3],
159+
c2: [v2],
160+
v2: [v3],
161+
v3: [],
162+
}
163+
for pred, successors in expected_edges.items():
164+
pred_node = key_to_node[pred.key]
165+
assert g1.count_successors(pred_node) == len(successors)
166+
for successor in successors:
167+
assert g1.has_successor(pred_node, key_to_node[successor.key])
168+
169+
170+
def test_replace_subgraph_without_removing_nodes():
171+
"""
172+
Original Graph:
173+
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
174+
175+
Target Graph:
176+
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
177+
s3 ---> c3 ---> v3
178+
179+
Nothing will be removed.
180+
Subgraph only contains [s3, c3, v3]
181+
"""
182+
s1 = mt.random.randint(0, 100, size=(10, 4))
183+
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
184+
s2 = mt.random.randint(0, 100, size=(10, 4))
185+
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
186+
v4 = v1.add(v2)
187+
g1 = v4.build_graph()
188+
189+
s3 = mt.random.randint(0, 100, size=(10, 4))
190+
v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5)
191+
g2 = v3.build_graph()
192+
key_to_node = {
193+
node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
194+
}
195+
expected_results = [v3.outputs[0], v4.outputs[0]]
196+
c1 = g1.successors(key_to_node[s1.key])[0]
197+
c2 = g1.successors(key_to_node[s2.key])[0]
198+
c3 = g2.successors(key_to_node[s3.key])[0]
199+
r = _MockRule(g1, None, None)
200+
r.replace_subgraph(g2, None, expected_results)
201+
assert g1.results == expected_results
202+
assert set(g1) == {
203+
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
204+
}
205+
expected_edges = {
206+
s1: [c1],
207+
c1: [v1],
208+
v1: [v4],
209+
s2: [c2],
210+
c2: [v2],
211+
v2: [v4],
212+
s3: [c3],
213+
c3: [v3],
214+
v3: [],
215+
v4: [],
216+
}
217+
for pred, successors in expected_edges.items():
218+
pred_node = key_to_node[pred.key]
219+
assert g1.count_successors(pred_node) == len(successors)
220+
for successor in successors:
221+
assert g1.has_successor(pred_node, key_to_node[successor.key])

0 commit comments

Comments
 (0)