Skip to content

Commit 3a00d06

Browse files
committed
Update tf.fill() to use dynamic size instead of const.
Fix MatrixBandPart to avoid dangling input references
1 parent 9dbd588 commit 3a00d06

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

tf2onnx/graph.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,9 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
10671067
Args:
10681068
node: we want to replace the input for this node
10691069
op_type: type for new operation
1070-
input_name: the names of the outputs above us
1070+
input_name: the name(s) of the outputs above us
1071+
    if a single name (string), the new node is placed above input_name
1072+
    if a list, the new node is placed above input_name[0]; the list elements become the inputs of the new node
10711073
name: the name of the new op
10721074
kwargs: attributes of the new node
10731075
@@ -1077,9 +1079,12 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=
10771079
if name is None:
10781080
name = utils.make_name(node.name)
10791081
new_output = port_name(name)
1080-
new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
1082+
if type(input_name) is not list:
1083+
input_name = [input_name]
1084+
1085+
new_node = self.make_node(op_type, input_name, attr=kwargs, outputs=[new_output], name=name, domain=domain)
10811086
for i, n in enumerate(node.input):
1082-
if n == input_name:
1087+
if n == input_name[0]:
10831088
node.input[i] = new_output
10841089
break
10851090
return new_node

tf2onnx/onnx_opset/generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from tf2onnx import utils
1717
from tf2onnx.handler import tf_op
1818

19-
2019
logger = logging.getLogger(__name__)
2120

21+
2222
# pylint: disable=unused-argument,missing-docstring
2323

2424
@tf_op(["Const", "ConstV2"])
@@ -102,6 +102,15 @@ def version_9(cls, ctx, node, **kwargs):
102102
node.set_attr("value", value_proto)
103103
del node.input[1]
104104

105+
@classmethod
106+
def version_11(cls, ctx, node, **kwargs):
107+
# cls.version_7(ctx, node, **kwargs)
108+
node.type = "Expand"
109+
node.input = [node.input[1], node.input[0]]
110+
# cast shape to int64 if needed
111+
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
112+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
113+
105114

106115
@tf_op("Multinomial")
107116
class Multinomial:

tf2onnx/onnx_opset/nn.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,8 @@ def version_11(cls, ctx, node, **kwargs):
452452
if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
453453
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
454454
ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
455-
reshape = ctx.insert_new_node_on_input(node, "Reshape", node.input[1])
456455
shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
457-
reshape.input = [reshape.input[0], shape_const.name]
456+
ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
458457

459458
origin_dtype = ctx.get_dtype(node.output[0])
460459
if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
@@ -715,8 +714,12 @@ def version_7(cls, ctx, node, **kwargs):
715714
# 2: "loop" to generate mask matrix: generate col or row of matrix one by one
716715
g = ctx.create_new_graph_with_same_config()
717716
node_name = utils.make_name("const_zero_bool")
718-
const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
719-
ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
717+
const_zero_bool = g.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
718+
g.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
719+
720+
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
721+
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
722+
g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
720723

721724
# shift right the line and add zero at the left.
722725
new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], "line"],
@@ -730,10 +733,6 @@ def version_7(cls, ctx, node, **kwargs):
730733
g.make_node("Identity", ["line"], outputs=["res"])
731734
g.make_node("Identity", [slice_node], outputs=["line_out"])
732735

733-
g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
734-
g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
735-
g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
736-
737736
g.add_graph_output("cond_out", onnx_pb.TensorProto.BOOL, [])
738737
g.add_graph_output("line_out", onnx_pb.TensorProto.BOOL, [-1, -1])
739738
g.add_graph_output("res", onnx_pb.TensorProto.BOOL, [-1, -1])

0 commit comments

Comments
 (0)