@@ -398,31 +398,7 @@ def _transpose_handler(self, trans, node):
398
398
return False
399
399
400
400
def _maxmin_handler (self , trans , node ):
401
- input_index = self ._get_input_index_for_trans (node , trans )
402
- all_other_inputs = [input_id for i , input_id in enumerate (node .input ) if i != input_index ]
403
-
404
- all_other_inputs_const = all ([self ._g .get_node_by_output (i ).is_const () for i in all_other_inputs ])
405
- if all_other_inputs_const is False :
406
- return False
407
-
408
- shapes = [len (self ._g .get_shape (i )) for i in all_other_inputs ]
409
- shapes_not_one_and_four = [s for s in shapes if s not in [1 , 4 ]]
410
- if shapes_not_one_and_four :
411
- return False
412
-
413
- for i in all_other_inputs :
414
- target_node = self ._g .get_node_by_output (i )
415
- numpy_val = target_node .get_tensor_value (as_list = False )
416
- rank = numpy_val .ndim
417
- if rank == 4 :
418
- transposed_val = np .transpose (numpy_val , (0 , 3 , 1 , 2 ))
419
- target_node .set_tensor_value (transposed_val )
420
- elif rank == 1 : # scalar
421
- # do nothing
422
- pass
423
- else :
424
- raise ValueError ("find rank !=1 and rank !=4, should not go here." )
425
- return self ._switch_transpose_and_node (node , trans )
401
+ return self ._handle_node_having_branches (node )
426
402
427
403
def _mul_handler (self , trans , node ):
428
404
multiplier_input_id = None
0 commit comments