Skip to content

Commit 9cb8257

Browse files
committed
Using merge instead of replace as the result updating strategy
1 parent c53e623 commit 9cb8257

File tree

2 files changed

+66
-37
lines changed

2 files changed

+66
-37
lines changed

mars/optimization/logical/core.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def _replace_subgraph(
135135
self,
136136
graph: Optional[EntityGraph],
137137
nodes_to_remove: Optional[Set[EntityType]],
138-
new_results: Optional[List[Entity]] = None,
138+
new_results: Optional[List[Entity]],
139+
results_to_remove: Optional[List[Entity]],
139140
):
140141
"""
141142
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
@@ -148,19 +149,28 @@ def _replace_subgraph(
148149
The input graph. If it's none, no new node and edge will be added.
149150
nodes_to_remove : Set[EntityType], optional
150151
The nodes to be removed. All the edges connected with them are removed as well.
151-
new_results : List[EntityType], optional, default None
152-
The updated results of the graph. If it's None, then the results will not be updated.
152+
new_results : List[Entity], optional
153+
The new results to be added to the graph.
154+
results_to_remove : List[Entity], optional
155+
The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored.
153156
154157
Raises
155158
------
156-
ReplaceSubgraphError
157-
If the input key of the removed node's successor can't be found in the subgraph.
158-
Or some of the nodes of the subgraph are in removed ones.
159+
ValueError
160+
1. If the input key of the removed node's successor can't be found in the subgraph.
161+
2. Or some of the nodes of the subgraph are in removed ones.
162+
3. Or the added result is not a valid output of any node in the updated graph.
159163
"""
160164
affected_successors = set()
161165

162166
output_to_node = dict()
163167
nodes_to_remove = nodes_to_remove or set()
168+
results_to_remove = results_to_remove or list()
169+
new_results = new_results or list()
170+
final_results = set(
171+
filter(lambda x: x not in results_to_remove, self._graph.results)
172+
)
173+
164174
if graph is not None:
165175
# Add the output key -> node of the subgraph
166176
for node in graph.iter_nodes():
@@ -169,6 +179,17 @@ def _replace_subgraph(
169179
for output in node.outputs:
170180
output_to_node[output.key] = node
171181

182+
# Add the output key -> node of the original graph
183+
for node in self._graph.iter_nodes():
184+
if node not in nodes_to_remove:
185+
for output in node.outputs:
186+
output_to_node[output.key] = node
187+
188+
for result in new_results:
189+
if result.key not in output_to_node:
190+
raise ValueError(f"Unknown result {result} to add")
191+
final_results.update(new_results)
192+
172193
for node in nodes_to_remove:
173194
for affected_successor in self._graph.iter_successors(node):
174195
if affected_successor not in nodes_to_remove:
@@ -180,17 +201,13 @@ def _replace_subgraph(
180201
raise ValueError(
181202
f"The output {inp} of node {affected_successor} is missing in the subgraph"
182203
)
204+
# Here all the pre-check are passed, we start to replace the subgraph
183205
for node in nodes_to_remove:
184206
self._graph.remove_node(node)
185207

186208
if graph is None:
187209
return
188210

189-
# Add the output key -> node of the original graph
190-
for node in self._graph.iter_nodes():
191-
for output in node.outputs:
192-
output_to_node[output.key] = node
193-
194211
for node in graph.iter_nodes():
195212
self._graph.add_node(node)
196213

@@ -199,8 +216,7 @@ def _replace_subgraph(
199216
pred_node = output_to_node[inp.key]
200217
self._graph.add_edge(pred_node, node)
201218

202-
if new_results is not None:
203-
self._graph.results = list(new_results)
219+
self._graph.results = list(final_results)
204220

205221
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
206222
pred_original = self._records.get_original_entity(predecessor, predecessor)

mars/optimization/logical/tests/test_core.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class _MockRule(OptimizationRule):
2424
def apply(self) -> bool:
2525
pass
2626

27-
def replace_subgraph(self, graph, removed_nodes, new_results=None):
28-
self._replace_subgraph(graph, removed_nodes, new_results)
27+
def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove):
28+
self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove)
2929

3030

3131
def test_replace_tileable_subgraph():
@@ -78,11 +78,15 @@ def test_replace_tileable_subgraph():
7878
c2 = g1.successors(key_to_node[s2.key])[0]
7979
c5 = g1.successors(key_to_node[s5.key])[0]
8080

81-
expected_results = [v8.outputs[0]]
81+
new_results = [v8.outputs[0]]
82+
removed_results = [
83+
v6.outputs[0],
84+
v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it.
85+
]
8286
r.replace_subgraph(
83-
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results
87+
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results
8488
)
85-
assert g1.results == expected_results
89+
assert g1.results == new_results
8690

8791
expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
8892
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}
@@ -110,10 +114,10 @@ def test_replace_tileable_subgraph():
110114
def test_replace_null_subgraph():
111115
"""
112116
Original Graph:
113-
s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
117+
s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
114118
115119
Target Graph:
116-
c1 ---> v1 ---> v3 <--- v2 <--- c2
120+
c1 ---> v1 ---> v3 <--- v2(out) <--- c2
117121
118122
The nodes [s1, s2] will be removed.
119123
Subgraph is None
@@ -129,30 +133,39 @@ def test_replace_null_subgraph():
129133
c2 = g1.successors(key_to_node[s2.key])[0]
130134
r = _MockRule(g1, None, None)
131135
expected_results = [v3.outputs[0]]
136+
132137
# delete c5 s5 will fail
133138
with pytest.raises(ValueError) as e:
134-
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
135-
assert g1.results == expected_results
136-
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
137-
expected_edges = {
138-
s1: [c1],
139-
c1: [v1],
140-
v1: [v3],
141-
s2: [c2],
142-
c2: [v2],
143-
v2: [v3],
144-
v3: [],
145-
}
146-
for pred, successors in expected_edges.items():
147-
pred_node = key_to_node[pred.key]
139+
r.replace_subgraph(
140+
None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]]
141+
)
142+
143+
assert g1.results == expected_results
144+
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
145+
expected_edges = {
146+
s1: [c1],
147+
c1: [v1],
148+
v1: [v3],
149+
s2: [c2],
150+
c2: [v2],
151+
v2: [v3],
152+
v3: [],
153+
}
154+
for pred, successors in expected_edges.items():
155+
pred_node = key_to_node[pred.key]
148156
assert g1.count_successors(pred_node) == len(successors)
149157
for successor in successors:
150158
assert g1.has_successor(pred_node, key_to_node[successor.key])
151159

152160
c1.inputs.clear()
153161
c2.inputs.clear()
154-
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
155-
assert g1.results == expected_results
162+
r.replace_subgraph(
163+
None,
164+
{key_to_node[op.key] for op in [s1, s2]},
165+
[v2.outputs[0]],
166+
[v3.outputs[0]],
167+
)
168+
assert g1.results == [v2.outputs[0]]
156169
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
157170
expected_edges = {
158171
c1: [v1],
@@ -198,7 +211,7 @@ def test_replace_subgraph_without_removing_nodes():
198211
c2 = g1.successors(key_to_node[s2.key])[0]
199212
c3 = g2.successors(key_to_node[s3.key])[0]
200213
r = _MockRule(g1, None, None)
201-
r.replace_subgraph(g2, None, expected_results)
214+
r.replace_subgraph(g2, None, [v3.outputs[0]], None)
202215
assert g1.results == expected_results
203216
assert set(g1) == {
204217
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}

0 commit comments

Comments
 (0)