1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import itertools
15+ from typing import Optional
16+
1517import pytest
1618
1719
@@ -24,8 +26,8 @@ class _MockRule(OptimizationRule):
2426 def apply (self ) -> bool :
2527 pass
2628
27- def replace_subgraph (self , graph , removed_nodes , new_results = None ):
28- self ._replace_subgraph (graph , removed_nodes , new_results )
29+ def replace_subgraph (self , graph , nodes_to_remove , new_results , results_to_remove ):
30+ self ._replace_subgraph (graph , nodes_to_remove , new_results , results_to_remove )
2931
3032
3133def test_replace_tileable_subgraph ():
@@ -78,11 +80,15 @@ def test_replace_tileable_subgraph():
7880 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
7981 c5 = g1 .successors (key_to_node [s5 .key ])[0 ]
8082
81- expected_results = [v8 .outputs [0 ]]
83+ new_results = [v8 .outputs [0 ]]
84+ removed_results = [
85+ v6 .outputs [0 ],
86+ v8 .outputs [0 ], # v8.outputs[0] is not in the original results, so we ignore it.
87+ ]
8288 r .replace_subgraph (
83- g2 , {key_to_node [op .key ] for op in [v3 , v4 , v6 ]}, expected_results
89+ g2 , {key_to_node [op .key ] for op in [v3 , v4 , v6 ]}, new_results , removed_results
8490 )
85- assert g1 .results == expected_results
91+ assert g1 .results == new_results
8692
8793 expected_nodes = {s1 , c1 , v1 , s2 , c2 , v2 , s5 , c5 , v5 , v7 , v8 }
8894 assert set (g1 ) == {key_to_node [n .key ] for n in expected_nodes }
@@ -110,10 +116,10 @@ def test_replace_tileable_subgraph():
110116def test_replace_null_subgraph ():
111117 """
112118 Original Graph:
113- s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
119+ s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
114120
115121 Target Graph:
116- c1 ---> v1 ---> v3 <--- v2 <--- c2
122+ c1 ---> v1 ---> v3 <--- v2(out) <--- c2
117123
118124 The nodes [s1, s2] will be removed.
119125 Subgraph is None
@@ -129,30 +135,39 @@ def test_replace_null_subgraph():
129135 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
130136 r = _MockRule (g1 , None , None )
131137 expected_results = [v3 .outputs [0 ]]
138+
132139 # delete c5 s5 will fail
133140 with pytest .raises (ValueError ) as e :
134- r .replace_subgraph (None , {key_to_node [op .key ] for op in [s1 , s2 ]})
135- assert g1 .results == expected_results
136- assert set (g1 ) == {key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , v3 }}
137- expected_edges = {
138- s1 : [c1 ],
139- c1 : [v1 ],
140- v1 : [v3 ],
141- s2 : [c2 ],
142- c2 : [v2 ],
143- v2 : [v3 ],
144- v3 : [],
145- }
146- for pred , successors in expected_edges .items ():
147- pred_node = key_to_node [pred .key ]
141+ r .replace_subgraph (
142+ None , {key_to_node [op .key ] for op in [s1 , s2 ]}, None , [v2 .outputs [0 ]]
143+ )
144+
145+ assert g1 .results == expected_results
146+ assert set (g1 ) == {key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , v3 }}
147+ expected_edges = {
148+ s1 : [c1 ],
149+ c1 : [v1 ],
150+ v1 : [v3 ],
151+ s2 : [c2 ],
152+ c2 : [v2 ],
153+ v2 : [v3 ],
154+ v3 : [],
155+ }
156+ for pred , successors in expected_edges .items ():
157+ pred_node = key_to_node [pred .key ]
148158 assert g1 .count_successors (pred_node ) == len (successors )
149159 for successor in successors :
150160 assert g1 .has_successor (pred_node , key_to_node [successor .key ])
151161
152162 c1 .inputs .clear ()
153163 c2 .inputs .clear ()
154- r .replace_subgraph (None , {key_to_node [op .key ] for op in [s1 , s2 ]})
155- assert g1 .results == expected_results
164+ r .replace_subgraph (
165+ None ,
166+ {key_to_node [op .key ] for op in [s1 , s2 ]},
167+ [v2 .outputs [0 ]],
168+ [v3 .outputs [0 ]],
169+ )
170+ assert g1 .results == [v2 .outputs [0 ]]
156171 assert set (g1 ) == {key_to_node [n .key ] for n in {c1 , v1 , c2 , v2 , v3 }}
157172 expected_edges = {
158173 c1 : [v1 ],
@@ -198,7 +213,7 @@ def test_replace_subgraph_without_removing_nodes():
198213 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
199214 c3 = g2 .successors (key_to_node [s3 .key ])[0 ]
200215 r = _MockRule (g1 , None , None )
201- r .replace_subgraph (g2 , None , expected_results )
216+ r .replace_subgraph (g2 , None , [ v3 . outputs [ 0 ]], None )
202217 assert g1 .results == expected_results
203218 assert set (g1 ) == {
204219 key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , s3 , c3 , v3 , v4 }
0 commit comments