Skip to content

Commit b23377c

Browse files
committed
changes for maxmin handler
1 parent b6e8056 commit b23377c

File tree

1 file changed

+1
-25
lines changed

1 file changed

+1
-25
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -398,31 +398,7 @@ def _transpose_handler(self, trans, node):
398398
return False
399399

400400
def _maxmin_handler(self, trans, node):
401-
input_index = self._get_input_index_for_trans(node, trans)
402-
all_other_inputs = [input_id for i, input_id in enumerate(node.input) if i != input_index]
403-
404-
all_other_inputs_const = all([self._g.get_node_by_output(i).is_const() for i in all_other_inputs])
405-
if all_other_inputs_const is False:
406-
return False
407-
408-
shapes = [len(self._g.get_shape(i)) for i in all_other_inputs]
409-
shapes_not_one_and_four = [s for s in shapes if s not in [1, 4]]
410-
if shapes_not_one_and_four:
411-
return False
412-
413-
for i in all_other_inputs:
414-
target_node = self._g.get_node_by_output(i)
415-
numpy_val = target_node.get_tensor_value(as_list=False)
416-
rank = numpy_val.ndim
417-
if rank == 4:
418-
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
419-
target_node.set_tensor_value(transposed_val)
420-
elif rank == 1: # scalar
421-
# do nothing
422-
pass
423-
else:
424-
raise ValueError("find rank !=1 and rank !=4, should not go here.")
425-
return self._switch_transpose_and_node(node, trans)
401+
return self._handle_node_having_branches(node)
426402

427403
def _mul_handler(self, trans, node):
428404
multiplier_input_id = None

0 commit comments

Comments
 (0)