Skip to content

Commit 84cfa03

Browse files
authored
Merge pull request #748 from jignparm/jignparm/fix_pad_for_dynamic_input
Fix Pad op for dynamic input (opset 11).
2 parents b25b9d4 + 0962a7e commit 84cfa03

File tree

3 files changed

+33
-16
lines changed

3 files changed

+33
-16
lines changed

tf2onnx/graph.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,9 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
10671067
Args:
10681068
node: we want to replace the input for this node
10691069
op_type: type for new operation
1070-
input_name: the names of the outputs above us
1070+
input_name: the name(s) of the outputs above us
1071+
if scalar, new node placed above input_name
1072+
if list, new node placed above input_name[0]. list is inputs into new node
10711073
name: the name of the new op
10721074
kwargs: attributes of the new node
10731075
@@ -1077,9 +1079,12 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
10771079
if name is None:
10781080
name = utils.make_name(node.name)
10791081
new_output = port_name(name)
1080-
new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
1082+
if not isinstance(input_name, list):
1083+
input_name = [input_name]
1084+
1085+
new_node = self.make_node(op_type, input_name, attr=kwargs, outputs=[new_output], name=name, domain=domain)
10811086
for i, n in enumerate(node.input):
1082-
if n == input_name:
1087+
if n == input_name[0]:
10831088
node.input[i] = new_output
10841089
break
10851090
return new_node

tf2onnx/onnx_opset/generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from tf2onnx import utils
1717
from tf2onnx.handler import tf_op
1818

19-
2019
logger = logging.getLogger(__name__)
2120

21+
2222
# pylint: disable=unused-argument,missing-docstring
2323

2424
@tf_op(["Const", "ConstV2"])
@@ -102,6 +102,15 @@ def version_9(cls, ctx, node, **kwargs):
102102
node.set_attr("value", value_proto)
103103
del node.input[1]
104104

105+
@classmethod
106+
def version_11(cls, ctx, node, **kwargs):
107+
# cls.version_7(ctx, node, **kwargs)
108+
node.type = "Expand"
109+
node.input = [node.input[1], node.input[0]]
110+
# cast shape to int64 if needed
111+
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
112+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
113+
105114

106115
@tf_op("Multinomial")
107116
class Multinomial:

tf2onnx/onnx_opset/nn.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def version_1(cls, ctx, node, **kwargs):
314314
k_h, k_w, k_input_channels, k_channel_multiplier = kernel_shape
315315
if k_input_channels < 1:
316316
raise ValueError("input channel must be positive")
317-
k_output_channels = k_input_channels * k_channel_multiplier
317+
k_output_channels = k_input_channels * k_channel_multiplier
318318

319319
node.set_attr("kernel_shape", [k_h, k_w])
320320
strides = conv_dims_attr(node, "strides")
@@ -476,13 +476,16 @@ def version_11(cls, ctx, node, **kwargs):
476476
if mode not in [None, "constant", "reflect"]:
477477
raise ValueError(mode + " pad mode is not supported")
478478

479-
pads = node.inputs[1].get_tensor_value()
480-
pads = np.array(pads).transpose().flatten().astype(np.int64)
481-
node.inputs[1].set_tensor_value(pads)
479+
# pads must be int64.
480+
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
481+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
482+
ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
483+
shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
484+
ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
482485

483486
origin_dtype = ctx.get_dtype(node.output[0])
484-
if origin_dtype not in [TensorProto.FLOAT16, TensorProto.FLOAT,
485-
TensorProto.DOUBLE]:
487+
if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
488+
TensorProto.INT32, TensorProto.INT64]:
486489
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
487490
cast_node.set_attr("to", TensorProto.FLOAT)
488491
ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
@@ -739,8 +742,12 @@ def version_7(cls, ctx, node, **kwargs):
739742
# 2: "loop" to generate mask matrix: generate col or row of matrix one by one
740743
g = ctx.create_new_graph_with_same_config()
741744
node_name = utils.make_name("const_zero_bool")
742-
const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
743-
ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
745+
const_zero_bool = g.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
746+
g.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
747+
748+
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
749+
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
750+
g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
744751

745752
# shift right the line and add zero at the left.
746753
new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], "line"],
@@ -754,10 +761,6 @@ def version_7(cls, ctx, node, **kwargs):
754761
g.make_node("Identity", ["line"], outputs=["res"])
755762
g.make_node("Identity", [slice_node], outputs=["line_out"])
756763

757-
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
758-
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
759-
g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
760-
761764
g.add_graph_output("cond_out", onnx_pb.TensorProto.BOOL, [])
762765
g.add_graph_output("line_out", onnx_pb.TensorProto.BOOL, [-1, -1])
763766
g.add_graph_output("res", onnx_pb.TensorProto.BOOL, [-1, -1])

0 commit comments

Comments (0)