@@ -24,8 +24,8 @@ class _MockRule(OptimizationRule):
2424 def apply (self ) -> bool :
2525 pass
2626
27- def replace_subgraph (self , graph , nodes_to_remove , new_results , results_to_remove ):
28- self ._replace_subgraph (graph , nodes_to_remove , new_results , results_to_remove )
27+ def replace_subgraph (self , graph , nodes_to_remove , new_results = None ):
28+ self ._replace_subgraph (graph , nodes_to_remove , new_results )
2929
3030
3131def test_replace_tileable_subgraph ():
@@ -61,8 +61,9 @@ def test_replace_tileable_subgraph():
6161 g1 = v6 .build_graph ()
6262 v7 = v1 .sub (v2 )
6363 v8 = v7 .add (v5 )
64+ v8 ._key = v6 .key
65+ v8 .outputs [0 ]._key = v6 .key
6466 g2 = v8 .build_graph ()
65-
6667 # Here we use a trick way to construct the subgraph for test only
6768 key_to_node = dict ()
6869 for node in g2 .iter_nodes ():
@@ -79,15 +80,12 @@ def test_replace_tileable_subgraph():
7980 c5 = g1 .successors (key_to_node [s5 .key ])[0 ]
8081
8182 new_results = [v8 .outputs [0 ]]
82- removed_results = [
83- v6 .outputs [0 ],
84- v8 .outputs [0 ], # v8.outputs[0] is not in the original results, so we ignore it.
85- ]
86- r .replace_subgraph (
87- g2 , {key_to_node [op .key ] for op in [v3 , v4 , v6 ]}, new_results , removed_results
88- )
83+ r .replace_subgraph (g2 , {key_to_node [op .key ] for op in [v3 , v4 , v6 ]}, new_results )
8984 assert g1 .results == new_results
90-
85+ for node in g1 .iter_nodes ():
86+ if node .key == v8 .key :
87+ key_to_node [v8 .key ] = node
88+ break
9189 expected_nodes = {s1 , c1 , v1 , s2 , c2 , v2 , s5 , c5 , v5 , v7 , v8 }
9290 assert set (g1 ) == {key_to_node [n .key ] for n in expected_nodes }
9391
@@ -117,7 +115,7 @@ def test_replace_null_subgraph():
117115 s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
118116
119117 Target Graph:
120- c1 ---> v1 ---> v3 <--- v2(out) <--- c2
118+ c1 ---> v1 ---> v3(out) <--- v2 <--- c2
121119
122120 The nodes [s1, s2] will be removed.
123121 Subgraph is None
@@ -137,7 +135,7 @@ def test_replace_null_subgraph():
137135 # delete c5 s5 will fail
138136 with pytest .raises (ValueError ):
139137 r .replace_subgraph (
140- None , {key_to_node [op .key ] for op in [s1 , s2 ]}, None , [v2 .outputs [0 ]]
138+ None , {key_to_node [op .key ] for op in [s1 , s2 ]}, [v2 .outputs [0 ]]
141139 )
142140
143141 assert g1 .results == expected_results
@@ -161,11 +159,9 @@ def test_replace_null_subgraph():
161159 c2 .inputs .clear ()
162160 r .replace_subgraph (
163161 None ,
164- {key_to_node [op .key ] for op in [s1 , s2 ]},
165- [v2 .outputs [0 ]],
166- [v3 .outputs [0 ]],
162+ {key_to_node [op .key ] for op in [s1 , s2 ]}
167163 )
168- assert g1 .results == [ v2 . outputs [ 0 ]]
164+ assert g1 .results == expected_results
169165 assert set (g1 ) == {key_to_node [n .key ] for n in {c1 , v1 , c2 , v2 , v3 }}
170166 expected_edges = {
171167 c1 : [v1 ],
@@ -206,12 +202,12 @@ def test_replace_subgraph_without_removing_nodes():
206202 key_to_node = {
207203 node .key : node for node in itertools .chain (g1 .iter_nodes (), g2 .iter_nodes ())
208204 }
209- expected_results = [v3 . outputs [ 0 ], v4 .outputs [0 ]]
205+ expected_results = [v4 .outputs [0 ]]
210206 c1 = g1 .successors (key_to_node [s1 .key ])[0 ]
211207 c2 = g1 .successors (key_to_node [s2 .key ])[0 ]
212208 c3 = g2 .successors (key_to_node [s3 .key ])[0 ]
213209 r = _MockRule (g1 , None , None )
214- r .replace_subgraph (g2 , None , [ v3 . outputs [ 0 ]], None )
210+ r .replace_subgraph (g2 , None )
215211 assert g1 .results == expected_results
216212 assert set (g1 ) == {
217213 key_to_node [n .key ] for n in {s1 , c1 , v1 , s2 , c2 , v2 , s3 , c3 , v3 , v4 }
0 commit comments