Skip to content

Commit dd02ba9

Browse files
committed
Change results updating algorithm
1 parent e2eaefc commit dd02ba9

File tree

2 files changed

+29
-34
lines changed

2 files changed

+29
-34
lines changed

mars/optimization/logical/core.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,7 @@ def _replace_subgraph(
135135
self,
136136
graph: Optional[EntityGraph],
137137
nodes_to_remove: Optional[Set[EntityType]],
138-
new_results: Optional[List[Entity]],
139-
results_to_remove: Optional[List[Entity]],
138+
new_results: Optional[List[Entity]] = None,
140139
):
141140
"""
142141
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
@@ -149,27 +148,24 @@ def _replace_subgraph(
149148
The input graph. If it's none, no new node and edge will be added.
150149
nodes_to_remove : Set[EntityType], optional
151150
The nodes to be removed. All the edges connected with them are removed as well.
152-
new_results : List[Entity], optional
153-
The new results to be added to the graph.
154-
results_to_remove : List[Entity], optional
155-
The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored.
151+
new_results : List[Entity], optional, default None
152+
The new results to be replaced to the original by their keys.
156153
157154
Raises
158155
------
159156
ValueError
160157
1. If the input key of the removed node's successor can't be found in the subgraph.
161158
2. Or some of the nodes of the subgraph are in removed ones.
162-
3. Or the added result is not a valid output of any node in the updated graph.
159+
3. Or some of the removed nodes are also in the results.
160+
4. Or the key of the new result can't be found in the original results.
163161
"""
164162
affected_successors = set()
165-
166163
output_to_node = dict()
167164
nodes_to_remove = nodes_to_remove or set()
168-
results_to_remove = results_to_remove or list()
169165
new_results = new_results or list()
170-
final_results = set(
171-
filter(lambda x: x not in results_to_remove, self._graph.results)
172-
)
166+
result_indices = {
167+
result.key: idx for idx, result in enumerate(self._graph.results)
168+
}
173169

174170
if graph is not None:
175171
# Add the output key -> node of the subgraph
@@ -185,10 +181,12 @@ def _replace_subgraph(
185181
for output in node.outputs:
186182
output_to_node[output.key] = node
187183

184+
# Check if the updated result is valid
188185
for result in new_results:
186+
if result.key not in result_indices:
187+
raise ValueError(f"Unknown result {result} to replace")
189188
if result.key not in output_to_node:
190-
raise ValueError(f"Unknown result {result} to add")
191-
final_results.update(new_results)
189+
raise ValueError(f"The result {result} is missing in the updated graph")
192190

193191
for node in nodes_to_remove:
194192
for affected_successor in self._graph.iter_successors(node):
@@ -216,7 +214,8 @@ def _replace_subgraph(
216214
pred_node = output_to_node[inp.key]
217215
self._graph.add_edge(pred_node, node)
218216

219-
self._graph.results = list(final_results)
217+
for result in new_results:
218+
self._graph.results[result_indices[result.key]] = result
220219

221220
def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
222221
pred_original = self._records.get_original_entity(predecessor, predecessor)

mars/optimization/logical/tests/test_core.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class _MockRule(OptimizationRule):
2424
def apply(self) -> bool:
2525
pass
2626

27-
def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove):
28-
self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove)
27+
def replace_subgraph(self, graph, nodes_to_remove, new_results=None):
28+
self._replace_subgraph(graph, nodes_to_remove, new_results)
2929

3030

3131
def test_replace_tileable_subgraph():
@@ -61,8 +61,9 @@ def test_replace_tileable_subgraph():
6161
g1 = v6.build_graph()
6262
v7 = v1.sub(v2)
6363
v8 = v7.add(v5)
64+
v8._key = v6.key
65+
v8.outputs[0]._key = v6.key
6466
g2 = v8.build_graph()
65-
6667
# Here we use a trick way to construct the subgraph for test only
6768
key_to_node = dict()
6869
for node in g2.iter_nodes():
@@ -79,15 +80,12 @@ def test_replace_tileable_subgraph():
7980
c5 = g1.successors(key_to_node[s5.key])[0]
8081

8182
new_results = [v8.outputs[0]]
82-
removed_results = [
83-
v6.outputs[0],
84-
v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it.
85-
]
86-
r.replace_subgraph(
87-
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results
88-
)
83+
r.replace_subgraph(g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results)
8984
assert g1.results == new_results
90-
85+
for node in g1.iter_nodes():
86+
if node.key == v8.key:
87+
key_to_node[v8.key] = node
88+
break
9189
expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
9290
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}
9391

@@ -117,7 +115,7 @@ def test_replace_null_subgraph():
117115
s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
118116
119117
Target Graph:
120-
c1 ---> v1 ---> v3 <--- v2(out) <--- c2
118+
c1 ---> v1 ---> v3(out) <--- v2 <--- c2
121119
122120
The nodes [s1, s2] will be removed.
123121
Subgraph is None
@@ -137,7 +135,7 @@ def test_replace_null_subgraph():
137135
# delete c5 s5 will fail
138136
with pytest.raises(ValueError):
139137
r.replace_subgraph(
140-
None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]]
138+
None, {key_to_node[op.key] for op in [s1, s2]}, [v2.outputs[0]]
141139
)
142140

143141
assert g1.results == expected_results
@@ -161,11 +159,9 @@ def test_replace_null_subgraph():
161159
c2.inputs.clear()
162160
r.replace_subgraph(
163161
None,
164-
{key_to_node[op.key] for op in [s1, s2]},
165-
[v2.outputs[0]],
166-
[v3.outputs[0]],
162+
{key_to_node[op.key] for op in [s1, s2]}
167163
)
168-
assert g1.results == [v2.outputs[0]]
164+
assert g1.results == expected_results
169165
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
170166
expected_edges = {
171167
c1: [v1],
@@ -206,12 +202,12 @@ def test_replace_subgraph_without_removing_nodes():
206202
key_to_node = {
207203
node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
208204
}
209-
expected_results = [v3.outputs[0], v4.outputs[0]]
205+
expected_results = [v4.outputs[0]]
210206
c1 = g1.successors(key_to_node[s1.key])[0]
211207
c2 = g1.successors(key_to_node[s2.key])[0]
212208
c3 = g2.successors(key_to_node[s3.key])[0]
213209
r = _MockRule(g1, None, None)
214-
r.replace_subgraph(g2, None, [v3.outputs[0]], None)
210+
r.replace_subgraph(g2, None)
215211
assert g1.results == expected_results
216212
assert set(g1) == {
217213
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}

0 commit comments

Comments
 (0)