Skip to content

Commit 79b218f

Browse files
committed
Using merge instead of replace as the result updating strategy
1 parent c53e623 commit 79b218f

File tree

2 files changed

+68
-37
lines changed

2 files changed

+68
-37
lines changed

mars/optimization/logical/core.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def _replace_subgraph(
135135
self,
136136
graph: Optional[EntityGraph],
137137
nodes_to_remove: Optional[Set[EntityType]],
138-
new_results: Optional[List[Entity]] = None,
138+
new_results: Optional[List[Entity]],
139+
results_to_remove: Optional[List[Entity]],
139140
):
140141
"""
141142
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
@@ -148,19 +149,28 @@ def _replace_subgraph(
148149
The input graph. If it's none, no new node and edge will be added.
149150
nodes_to_remove : Set[EntityType], optional
150151
The nodes to be removed. All the edges connected with them are removed as well.
151-
new_results : List[EntityType], optional, default None
152-
The updated results of the graph. If it's None, then the results will not be updated.
152+
new_results : List[Entity], optional
153+
The new results to be added to the graph.
154+
results_to_remove : List[Entity], optional
155+
The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored.
153156
154157
Raises
155158
------
156-
ReplaceSubgraphError
157-
If the input key of the removed node's successor can't be found in the subgraph.
158-
Or some of the nodes of the subgraph are in removed ones.
159+
ValueError
160+
1. If the input key of the removed node's successor can't be found in the subgraph.
161+
2. Or some of the nodes of the subgraph are in removed ones.
162+
3. Or the added result is not a valid output of any node in the updated graph.
159163
"""
160164
affected_successors = set()
161165

162166
output_to_node = dict()
163167
nodes_to_remove = nodes_to_remove or set()
168+
results_to_remove = results_to_remove or list()
169+
new_results = new_results or list()
170+
final_results = set(
171+
filter(lambda x: x not in results_to_remove, self._graph.results)
172+
)
173+
164174
if graph is not None:
165175
# Add the output key -> node of the subgraph
166176
for node in graph.iter_nodes():
@@ -169,6 +179,17 @@ def _replace_subgraph(
169179
for output in node.outputs:
170180
output_to_node[output.key] = node
171181

182+
# Add the output key -> node of the original graph
183+
for node in self._graph.iter_nodes():
184+
if node not in nodes_to_remove:
185+
for output in node.outputs:
186+
output_to_node[output.key] = node
187+
188+
for result in new_results:
189+
if result.key not in output_to_node:
190+
raise ValueError(f"Unknown result {result} to add")
191+
final_results.update(new_results)
192+
172193
for node in nodes_to_remove:
173194
for affected_successor in self._graph.iter_successors(node):
174195
if affected_successor not in nodes_to_remove:
@@ -180,17 +201,13 @@ def _replace_subgraph(
180201
raise ValueError(
181202
f"The output {inp} of node {affected_successor} is missing in the subgraph"
182203
)
204+
# Here all the pre-check are passed, we start to replace the subgraph
183205
for node in nodes_to_remove:
184206
self._graph.remove_node(node)
185207

186208
if graph is None:
187209
return
188210

189-
# Add the output key -> node of the original graph
190-
for node in self._graph.iter_nodes():
191-
for output in node.outputs:
192-
output_to_node[output.key] = node
193-
194211
for node in graph.iter_nodes():
195212
self._graph.add_node(node)
196213

@@ -199,8 +216,7 @@ def _replace_subgraph(
199216
pred_node = output_to_node[inp.key]
200217
self._graph.add_edge(pred_node, node)
201218

202-
if new_results is not None:
203-
self._graph.results = list(new_results)
219+
self._graph.results = list(final_results)
204220

205221
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
206222
pred_original = self._records.get_original_entity(predecessor, predecessor)

mars/optimization/logical/tests/test_core.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import itertools
15+
from typing import Optional
16+
1517
import pytest
1618

1719

@@ -24,8 +26,8 @@ class _MockRule(OptimizationRule):
2426
def apply(self) -> bool:
2527
pass
2628

27-
def replace_subgraph(self, graph, removed_nodes, new_results=None):
28-
self._replace_subgraph(graph, removed_nodes, new_results)
29+
def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove):
30+
self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove)
2931

3032

3133
def test_replace_tileable_subgraph():
@@ -78,11 +80,15 @@ def test_replace_tileable_subgraph():
7880
c2 = g1.successors(key_to_node[s2.key])[0]
7981
c5 = g1.successors(key_to_node[s5.key])[0]
8082

81-
expected_results = [v8.outputs[0]]
83+
new_results = [v8.outputs[0]]
84+
removed_results = [
85+
v6.outputs[0],
86+
v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it.
87+
]
8288
r.replace_subgraph(
83-
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results
89+
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results
8490
)
85-
assert g1.results == expected_results
91+
assert g1.results == new_results
8692

8793
expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
8894
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}
@@ -110,10 +116,10 @@ def test_replace_tileable_subgraph():
110116
def test_replace_null_subgraph():
111117
"""
112118
Original Graph:
113-
s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
119+
s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
114120
115121
Target Graph:
116-
c1 ---> v1 ---> v3 <--- v2 <--- c2
122+
c1 ---> v1 ---> v3 <--- v2(out) <--- c2
117123
118124
The nodes [s1, s2] will be removed.
119125
Subgraph is None
@@ -129,30 +135,39 @@ def test_replace_null_subgraph():
129135
c2 = g1.successors(key_to_node[s2.key])[0]
130136
r = _MockRule(g1, None, None)
131137
expected_results = [v3.outputs[0]]
138+
132139
# delete c5 s5 will fail
133140
with pytest.raises(ValueError) as e:
134-
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
135-
assert g1.results == expected_results
136-
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
137-
expected_edges = {
138-
s1: [c1],
139-
c1: [v1],
140-
v1: [v3],
141-
s2: [c2],
142-
c2: [v2],
143-
v2: [v3],
144-
v3: [],
145-
}
146-
for pred, successors in expected_edges.items():
147-
pred_node = key_to_node[pred.key]
141+
r.replace_subgraph(
142+
None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]]
143+
)
144+
145+
assert g1.results == expected_results
146+
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
147+
expected_edges = {
148+
s1: [c1],
149+
c1: [v1],
150+
v1: [v3],
151+
s2: [c2],
152+
c2: [v2],
153+
v2: [v3],
154+
v3: [],
155+
}
156+
for pred, successors in expected_edges.items():
157+
pred_node = key_to_node[pred.key]
148158
assert g1.count_successors(pred_node) == len(successors)
149159
for successor in successors:
150160
assert g1.has_successor(pred_node, key_to_node[successor.key])
151161

152162
c1.inputs.clear()
153163
c2.inputs.clear()
154-
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
155-
assert g1.results == expected_results
164+
r.replace_subgraph(
165+
None,
166+
{key_to_node[op.key] for op in [s1, s2]},
167+
[v2.outputs[0]],
168+
[v3.outputs[0]],
169+
)
170+
assert g1.results == [v2.outputs[0]]
156171
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
157172
expected_edges = {
158173
c1: [v1],
@@ -198,7 +213,7 @@ def test_replace_subgraph_without_removing_nodes():
198213
c2 = g1.successors(key_to_node[s2.key])[0]
199214
c3 = g2.successors(key_to_node[s3.key])[0]
200215
r = _MockRule(g1, None, None)
201-
r.replace_subgraph(g2, None, expected_results)
216+
r.replace_subgraph(g2, None, [v3.outputs[0]], None)
202217
assert g1.results == expected_results
203218
assert set(g1) == {
204219
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}

0 commit comments

Comments
 (0)