Skip to content

Commit 6eae86e

Browse files
authored
Merge pull request #644 from lei-Qiao/add_handler
Fix Add handler
2 parents 1da54e1 + e67d9bd commit 6eae86e

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -383,16 +383,17 @@ def _add_handler(self, trans, node):
383383
numpy_val = target_node.get_tensor_value(as_list=False)
384384
# Optional 1D bias to be added to the convolution, has size of M
385385
if len(numpy_val.shape) - numpy_val.shape.count(1) > 1:
386+
self.logger.debug("Bias is not 1D, can not merge Conv and Add")
386387
return self._handle_node_having_branches(node)
387388

388-
rank = len(numpy_val.shape)
389-
utils.make_sure(rank in (1, 4), "only support bias rank = 4 or 1")
390-
# to make rank = 4
391-
if rank == 1:
392-
numpy_val = numpy_val.reshape((1, 1, 1, numpy_val.shape[0]))
389+
bias_size = max(numpy_val.shape)
390+
size_m = t_p.inputs[1].output_shapes[0][0]
391+
if bias_size != size_m:
392+
self.logger.debug("Bias size is not M, can not merge Conv and Add")
393+
return self._handle_node_having_branches(node)
393394

394-
transposed_val = np.transpose(numpy_val, (0, 3, 1, 2))
395-
target_node.set_tensor_value(transposed_val)
395+
target_val = numpy_val.reshape(bias_size)
396+
target_node.set_tensor_value(target_val)
396397

397398
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
398399
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)

0 commit comments

Comments
 (0)