@@ -129,6 +129,7 @@ def in_redirect(self, old_name, name):
129129 self .input [key ] = name
130130
131131 def out_redirect (self , old_name , name ):
132+ assert self .in_or_out
132133 if old_name in self .output :
133134 self .output [old_name ] = name
134135 else :
@@ -191,10 +192,10 @@ def build_from_onnx(onnx_nodes, nchw_inputs, inputs, outputs):
191192 if var_ in nchw_inputs :
192193 nnode = LinkedNode (
193194 helper .make_node (
194- 'Transpose' ,
195- [var_ ],
196- [new_output ],
197- perm = [0 , 2 , 3 , 1 ]))
195+ 'Transpose' ,
196+ [var_ ],
197+ [new_output ],
198+ perm = [0 , 2 , 3 , 1 ]))
198199 var_map [new_output ] = nnode
199200 nnode .add_precedence (target , var_ )
200201 n_ .in_redirect (var_ , new_output )
@@ -236,6 +237,10 @@ def debug_print(node_list):
236237
237238
238239class Solution (object ):
240+ """
241+ Solution is the base class for solutions, and it has a basic function is to
242+ delete the node range of (begin, begin_n, end_p, end), where 'begin' and 'end' are excluded.
243+ """
239244 def __init__ (self , begin , begin_n , end_p , end ):
240245 self .begin = begin
241246 self .begin_n = begin_n
@@ -255,23 +260,64 @@ def is_useless_transpose(perm):
255260 return perm == list (six .moves .range (len (perm )))
256261
257262 @staticmethod
258- def delete_node (node_list , begin , node , end ): # type: ([],LinkedNode, LinkedNode, LinkedNode)->[]
263+ def delete_node_nto1 (node_list , begin , node , end ): # type: ([],LinkedNode, LinkedNode, LinkedNode)->[]
264+ """
265+ delete the node which has n-input and 1-output
266+ """
267+ if begin is None :
268+ assert node is not None
269+ begin = node .precedence
270+ elif not isinstance (begin , list ):
271+ begin = [begin ]
272+
259273 if end .in_or_out :
274+ # if the end is output node, the output name will be kept to avoid the model output name updating.
275+ for nb_ in begin :
276+ nb_ .out_redirect (node .single_input , node .single_output )
277+ else :
278+ for nb_ in begin :
279+ target_var_name = node .single_input
280+ assert target_var_name in nb_ .output .values () # since the output info never be updated, except the final.
281+ end .in_redirect (node .single_output , target_var_name )
282+
283+ for nb_ in begin :
284+ nb_ .successor = [end if v_ == node else v_ for v_ in nb_ .successor ]
285+ end .precedence = [v_ for v_ in end .precedence if v_ != node ] + node .precedence
286+
287+ node_list .remove (node )
288+ return node_list
289+
290+ @staticmethod
291+ def delete_node_1ton (node_list , begin , node , end ): # type: ([],LinkedNode, LinkedNode, LinkedNode)->[]
292+ """
293+ delete the node which has 1-input and n-output
294+ """
295+ if end is None :
296+ assert end is not None
297+ end = node .successor
298+ elif not isinstance (end , list ):
299+ end = [end ]
300+
301+ if any (e_ .in_or_out for e_ in end ):
260302 # if the end is output node, the output name will be kept to avoid the model output name updating.
261303 begin .out_redirect (node .single_input , node .single_output )
262304 else :
263- target_var_name = node .single_input
264- assert target_var_name in begin .output .values () # since the output info never be updated, except the final.
265- end .in_redirect (node .single_output , target_var_name )
305+ for ne_ in end :
306+ target_var_name = node .single_input
307+ # since the output info never be updated, except the final.
308+ assert target_var_name in begin .output .values ()
309+ ne_ .in_redirect (node .single_output , target_var_name )
266310
267- begin .successor = [end if v_ == node else v_ for v_ in begin .successor ]
268- end .precedence = [begin if v_ == node else v_ for v_ in end .precedence ]
311+ begin .successor = [v_ for v_ in begin .successor if v_ != node ] + node .successor
312+ for ne_ in end :
313+ ne_ .precedence = [begin if v_ == node else v_ for v_ in ne_ .precedence ]
269314
270315 node_list .remove (node )
271316 return node_list
272317
273318 @staticmethod
274- def add_siso_node (node_list , begin , end , begin_output_name , node ): # type: ([], LinkedNode, LinkedNode, string, LinkedNode)->[]
319+ def add_siso_node (node_list , begin , end , begin_output_name , node ):
320+ # type: ([], LinkedNode, LinkedNode, str, LinkedNode)->[]
275321 node .in_redirect (node .single_input , begin_output_name )
276322 end .in_redirect (begin_output_name , node .single_output )
277323 begin .successor [begin .successor .index (end )] = node
@@ -287,8 +333,11 @@ def apply(self, node_list):
287333 while node != self .end :
288334 assert len (node .successor ) == 1
289335 end = node .successor [0 ]
290- node_list = self .delete_node (node_list , self .begin , node , end )
291- node = end
336+ if self .begin :
337+ node_list = self .delete_node_nto1 (node_list , self .begin , node , end )
338+ else :
339+ node_list = self .delete_node_nto1 (node_list , self .begin , node , end )
340+ node = self .end if self .end is None else end
292341
293342 return node_list
294343
@@ -306,10 +355,10 @@ def apply(self, node_list):
306355 # node.reshape_input_for_broadcast(perm0)
307356 node = node .successor [0 ]
308357
309- node_list = self .delete_node (node_list , self .begin , self .begin_n , self .begin_n .successor [0 ])
310- node_list = self .delete_node (node_list , self .end_p .precedence [0 ], self .end_p , self .end )
358+ node_list = self .delete_node_1ton (node_list , self .begin , self .begin_n , self .begin_n .successor [0 ])
359+ node_list = self .delete_node_1ton (node_list , self .end_p .precedence [0 ], self .end_p , self .end )
311360 else :
312- node_list = self .delete_node (node_list , self .begin_n , self .end_p , self .end )
361+ node_list = self .delete_node_1ton (node_list , self .begin_n , self .end_p , self .end )
313362 self .begin_n .attribute ['perm' ] = perm_f
314363 return node_list
315364
@@ -346,7 +395,7 @@ def apply(self, node_list):
346395 FanOutSolution .number = FanOutSolution .number + 1
347396 node_list = Solution .add_siso_node (node_list , self .end_p , suc , list (suc .input .values ())[0 ], nnode )
348397
349- node_list = Solution .delete_node (node_list , self .begin , self .begin_n , self .end_p )
398+ node_list = Solution .delete_node_1ton (node_list , self .begin , self .begin_n , self .end_p )
350399 return node_list
351400
352401
@@ -368,7 +417,7 @@ def apply(self, node_list):
368417 precedence_list = self .begin .precedence .copy ()
369418 node_list = Solution .add_siso_node (node_list , self .begin , self .begin_n , list (self .begin .output .values ())[0 ], nnode )
370419 for branch in precedence_list :
371- node_list = Solution .delete_node (node_list , branch .precedence [0 ], branch , self .begin )
420+ node_list = Solution .delete_node_1ton (node_list , branch .precedence [0 ], branch , self .begin )
372421 return node_list
373422
374423
0 commit comments