Skip to content

Commit 8efd659

Browse files
committed
fix add handler when input rank == 1
check if transpose and conv have only one consumer
1 parent 7c5340c commit 8efd659

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,27 @@ def _add_handler(self, trans, node):
371371
# if Conv or ConvTranspose's bias input is not set, then we set, otherwise, we don't set
372372
# todo: maybe we can add already set bias with the input??? try later
373373

374+
if not self._nodes_has_single_consumer_node([t_p]):
375+
self.logger.debug("Conv does not have single consumer, can not merge Conv and Add")
376+
return self._handle_node_having_branches(node)
377+
378+
if not self._nodes_has_single_consumer_node([trans]):
379+
self.logger.debug("input transpose does not have single consumer, skipping...")
380+
return False
381+
374382
target_node = node.inputs[1]
375383
numpy_val = target_node.get_tensor_value(as_list=False)
376384
# Optional 1D bias to be added to the convolution, has size of M
377385
if len(numpy_val.shape) - numpy_val.shape.count(1) > 1:
378386
return self._handle_node_having_branches(node)
387+
388+
rank = len(numpy_val.shape)
389+
utils.make_sure(rank in (1, 4), "only support bias rank = 4 or 1")
390+
# to make rank = 4
391+
if rank == 1:
392+
for _ in range(3):
393+
numpy_val = np.expand_dims(numpy_val, axis=0)
394+
379395
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
380396
target_node.set_tensor_value(transposed_val)
381397

0 commit comments

Comments
 (0)