@@ -24,8 +24,8 @@ class _MockRule(OptimizationRule):
2424 def apply (self ) -> bool :
2525 pass
2626
27- def replace_subgraph (self , graph , removed_nodes , new_results = None ):
28- self ._replace_subgraph (graph , removed_nodes , new_results )
27+ def replace_subgraph (self , graph , nodes_to_remove , new_results , results_to_remove ):
28+ self ._replace_subgraph (graph , nodes_to_remove , new_results , results_to_remove )
2929
3030
3131def test_replace_tileable_subgraph ():
@@ -78,11 +78,15 @@ def test_replace_tileable_subgraph():
7878 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
7979 c5 = g1 .successors (key_to_node [s5 .key ])[0 ]
8080
81- expected_results = [v8 .outputs [0 ]]
81+ new_results = [v8 .outputs [0 ]]
82+ removed_results = [
83+ v6 .outputs [0 ],
84+ v8 .outputs [0 ], # v8.outputs[0] is not in the original results, so we ignore it.
85+ ]
8286 r .replace_subgraph (
83- g2 , {key_to_node [op .key ] for op in [v3 , v4 , v6 ]}, expected_results
87+ g2 , {key_to_node [op .key ] for op in [v3 , v4 , v6 ]}, new_results , removed_results
8488 )
85- assert g1 .results == expected_results
89+ assert g1 .results == new_results
8690
8791 expected_nodes = {s1 , c1 , v1 , s2 , c2 , v2 , s5 , c5 , v5 , v7 , v8 }
8892 assert set (g1 ) == {key_to_node [n .key ] for n in expected_nodes }
@@ -110,10 +114,10 @@ def test_replace_tileable_subgraph():
110114def test_replace_null_subgraph ():
111115 """
112116 Original Graph:
113- s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
117+ s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
114118
115119 Target Graph:
116- c1 ---> v1 ---> v3 <--- v2 <--- c2
120+ c1 ---> v1 ---> v3 <--- v2(out) <--- c2
117121
118122 The nodes [s1, s2] will be removed.
119123 Subgraph is None
@@ -129,30 +133,39 @@ def test_replace_null_subgraph():
129133 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
130134 r = _MockRule (g1 , None , None )
131135 expected_results = [v3 .outputs [0 ]]
136+
132137 # delete c5 s5 will fail
133138 with pytest .raises (ValueError ) as e :
134- r .replace_subgraph (None , {key_to_node [op .key ] for op in [s1 , s2 ]})
135- assert g1 .results == expected_results
136- assert set (g1 ) == {key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , v3 }}
137- expected_edges = {
138- s1 : [c1 ],
139- c1 : [v1 ],
140- v1 : [v3 ],
141- s2 : [c2 ],
142- c2 : [v2 ],
143- v2 : [v3 ],
144- v3 : [],
145- }
146- for pred , successors in expected_edges .items ():
147- pred_node = key_to_node [pred .key ]
139+ r .replace_subgraph (
140+ None , {key_to_node [op .key ] for op in [s1 , s2 ]}, None , [v2 .outputs [0 ]]
141+ )
142+
143+ assert g1 .results == expected_results
144+ assert set (g1 ) == {key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , v3 }}
145+ expected_edges = {
146+ s1 : [c1 ],
147+ c1 : [v1 ],
148+ v1 : [v3 ],
149+ s2 : [c2 ],
150+ c2 : [v2 ],
151+ v2 : [v3 ],
152+ v3 : [],
153+ }
154+ for pred , successors in expected_edges .items ():
155+ pred_node = key_to_node [pred .key ]
148156 assert g1 .count_successors (pred_node ) == len (successors )
149157 for successor in successors :
150158 assert g1 .has_successor (pred_node , key_to_node [successor .key ])
151159
152160 c1 .inputs .clear ()
153161 c2 .inputs .clear ()
154- r .replace_subgraph (None , {key_to_node [op .key ] for op in [s1 , s2 ]})
155- assert g1 .results == expected_results
162+ r .replace_subgraph (
163+ None ,
164+ {key_to_node [op .key ] for op in [s1 , s2 ]},
165+ [v2 .outputs [0 ]],
166+ [v3 .outputs [0 ]],
167+ )
168+ assert g1 .results == [v2 .outputs [0 ]]
156169 assert set (g1 ) == {key_to_node [n .key ] for n in {c1 , v1 , c2 , v2 , v3 }}
157170 expected_edges = {
158171 c1 : [v1 ],
@@ -198,7 +211,7 @@ def test_replace_subgraph_without_removing_nodes():
198211 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
199212 c3 = g2 .successors (key_to_node [s3 .key ])[0 ]
200213 r = _MockRule (g1 , None , None )
201- r .replace_subgraph (g2 , None , expected_results )
214+ r .replace_subgraph (g2 , None , [ v3 . outputs [ 0 ]], None )
202215 assert g1 .results == expected_results
203216 assert set (g1 ) == {
204217 key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , s3 , c3 , v3 , v4 }
0 commit comments