Skip to content

Commit 2a5e655

Browse files
author
wayuanho
authored
Merge pull request #637 from lei-Qiao/add_handler
fix Add handler
2 parents e59907b + c7d9e8f commit 2a5e655

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,26 @@ def _add_handler(self, trans, node):
371371
# if Conv or ConvTranspose's bias input is not set, then we set, otherwise, we don't set
372372
# todo: maybe we can add already set bias with the input??? try later
373373

374+
if not self._nodes_has_single_consumer_node([t_p]):
375+
self.logger.debug("Conv does not have single consumer, can not merge Conv and Add")
376+
return self._handle_node_having_branches(node)
377+
378+
if not self._nodes_has_single_consumer_node([trans]):
379+
self.logger.debug("input transpose does not have single consumer, skipping...")
380+
return False
381+
374382
target_node = node.inputs[1]
375383
numpy_val = target_node.get_tensor_value(as_list=False)
376384
# Optional 1D bias to be added to the convolution, has size of M
377385
if len(numpy_val.shape) - numpy_val.shape.count(1) > 1:
378386
return self._handle_node_having_branches(node)
387+
388+
rank = len(numpy_val.shape)
389+
utils.make_sure(rank in (1, 4), "only support bias rank = 4 or 1")
390+
# to make rank = 4
391+
if rank == 1:
392+
numpy_val = numpy_val.reshape((1, 1, 1, numpy_val.shape[0]))
393+
379394
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
380395
target_node.set_tensor_value(transposed_val)
381396

0 commit comments

Comments
 (0)