Skip to content

Commit 4e7b167

Browse files
authored
Merge pull request #310 from pengwa/fix_reduce_logic
fix a few issues found in a real model
2 parents da2bfbd + 30fa07a commit 4e7b167

File tree

3 files changed

+38
-29
lines changed

3 files changed

+38
-29
lines changed

tests/run_pretrained_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def run_onnxruntime(self, name, model_proto, inputs):
205205
"""Run test against msrt-next backend."""
206206
import onnxruntime as rt
207207
model_path = utils.save_onnx_model(TMPPATH, name, inputs, model_proto, include_test_data=True)
208+
utils.save_onnx_model(TMPPATH, name, inputs, model_proto, include_test_data=False, as_text=True)
208209
print("\t\t" + model_path)
209210
m = rt.InferenceSession(model_path)
210211
results = m.run(self.output_names, inputs)

tests/test_graph.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_argminmax(self):
207207
_ = tf.identity(x_, name="output")
208208
g = process_tf_graph(sess.graph)
209209
self.assertEqual(
210-
'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] ' \
210+
'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
211211
'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
212212
onnx_to_graphviz(g))
213213

@@ -260,12 +260,12 @@ def test_conv2d(self):
260260

261261
g = process_tf_graph(sess.graph)
262262
self.assertEqual(
263-
'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__2 [op_type=Transpose] '
264-
'"kernel/shape" [op_type=Const] k [op_type=Const] kernel [op_type=Reshape] '
265-
'Conv2D__3 [op_type=Transpose] Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] '
266-
'output [op_type=Identity] input1:0 -> Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel '
267-
'kernel:0 -> Conv2D__3 Conv2D__2:0 -> Conv2D Conv2D__3:0 -> Conv2D '
268-
'Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }',
263+
'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose] '
264+
'"kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
265+
'kernel [op_type=Reshape] Conv2D__4 [op_type=Transpose] Conv2D [op_type=Conv] '
266+
'Conv2D__5 [op_type=Transpose] output [op_type=Identity] input1:0 -> Conv2D__3 '
267+
'"kernel/shape":0 -> kernel__2 k:0 -> kernel kernel__2:0 -> kernel kernel:0 -> Conv2D__4 '
268+
'Conv2D__3:0 -> Conv2D Conv2D__4:0 -> Conv2D Conv2D:0 -> Conv2D__5 Conv2D__5:0 -> output }',
269269
onnx_to_graphviz(g))
270270

271271
def test_squeeze(self):
@@ -275,7 +275,7 @@ def test_squeeze(self):
275275
_ = tf.identity(x_, name="output")
276276
g = process_tf_graph(sess.graph)
277277
self.assertEqual(
278-
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '\
278+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '
279279
'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }',
280280
onnx_to_graphviz(g))
281281

@@ -286,7 +286,7 @@ def test_cast(self):
286286
_ = tf.identity(x_, name="output")
287287
g = process_tf_graph(sess.graph)
288288
self.assertEqual(
289-
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '\
289+
'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '
290290
'input1:0 -> Cast Cast:0 -> output }',
291291
onnx_to_graphviz(g))
292292

@@ -297,9 +297,10 @@ def test_reshape(self):
297297
_ = tf.identity(x_, name="output")
298298
g = process_tf_graph(sess.graph)
299299
self.assertEqual(
300-
'digraph { "Reshape/shape" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
301-
'Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
302-
'"Reshape/shape":0 -> Reshape Reshape:0 -> output }',
300+
'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
301+
'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
302+
'output [op_type=Identity] "Reshape/shape":0 -> Reshape__2 input1:0 -> Reshape '
303+
'Reshape__2:0 -> Reshape Reshape:0 -> output }',
303304
onnx_to_graphviz(g))
304305

305306
def test_custom_rewrite(self):

tf2onnx/tfonnx.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,6 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
147147
"""cast int32 shape into int64 shape."""
148148
shape_node = node.inputs[input_number]
149149
name = node.input[input_number]
150-
if shape_node.is_const():
151-
# if it is a const, change the const to be int64
152-
shape = shape_node.get_tensor_value()
153-
shape = np.array(list(shape), dtype=np.int64)
154-
shape_node.set_tensor_value(shape)
155-
ctx.set_dtype(shape_node.output[0], onnx_pb.TensorProto.INT64)
156-
ctx.copy_shape(name, shape_node.output[0])
157-
return [node]
158150

159151
cast_node = ctx.insert_new_node_on_input(node, "Cast", name)
160152
cast_node.set_attr("to", onnx_pb.TensorProto.INT64)
@@ -902,6 +894,7 @@ def pad_op(ctx, node, name, args):
902894
# or PadV2(T input, int32 paddings, T constant_value, @type Tpaddings), CONST mode - default value specified
903895
# or MirrorPad(T input, int32 paddings, @type Tpaddings, @STRING mode), other mode.
904896
# T output = Pad(T data, @STRING mode, @INTS pads, @FLOAT value)
897+
nodes = [node]
905898
paddings = np.array(node.inputs[1].get_tensor_value()).transpose().flatten()
906899
mode = node.get_attr("mode")
907900
if mode:
@@ -917,7 +910,24 @@ def pad_op(ctx, node, name, args):
917910

918911
ctx.remove_input(node, node.input[1])
919912
node.set_attr("pads", paddings)
920-
return node
913+
914+
origin_dtype = ctx.get_dtype(node.output[0])
915+
if origin_dtype not in [onnx_pb.TensorProto.FLOAT16, onnx_pb.TensorProto.FLOAT,
916+
onnx_pb.TensorProto.DOUBLE]:
917+
cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
918+
cast_node.set_attr("to", onnx_pb.TensorProto.FLOAT)
919+
ctx.set_dtype(cast_node.output[0], onnx_pb.TensorProto.FLOAT)
920+
ctx.copy_shape(name, cast_node.output[0])
921+
nodes.append(cast_node)
922+
923+
cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
924+
name=utils.make_name(node.name) + "_castback")
925+
cast_back_node.set_attr("to", origin_dtype)
926+
ctx.set_dtype(cast_back_node.output[0], origin_dtype)
927+
ctx.copy_shape(name, cast_back_node.output[0])
928+
nodes.append(cast_back_node)
929+
930+
return nodes
921931

922932

923933
def rsqrt_op(ctx, node, name, args):
@@ -1222,11 +1232,6 @@ def minmax_op(ctx, node, name, args):
12221232

12231233

12241234
def pack_op(ctx, node, name, args):
1225-
# in tf, "pack" can accept one input tensor which means doing nothing,
1226-
# so remove the node in ONNX
1227-
if len(node.inputs) == 1:
1228-
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], node.input[0])
1229-
return None
12301235

12311236
# hack to make up for the missing onnx pack op
12321237
axis = node.get_attr("axis").i
@@ -1650,13 +1655,15 @@ def reduce_logic_op(ctx, node, name, args):
16501655

16511656
utils.make_sure(all(i >= 0 for i in reduce_dim), "negative reduce axis is not supported in onnx for now")
16521657

1653-
cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.INT32})
1658+
cast = ctx.make_node(op_type="Cast", inputs=[node.input[0]], attr={"to": onnx_pb.TensorProto.FLOAT})
16541659
keepdims = helper.get_attribute_value(node.get_attr("keep_dims"))
16551660
op_type = "ReduceMin" if node.type == "All" else "ReduceSum"
16561661
reduce_node = ctx.make_node(op_type=op_type, inputs=cast.output, attr={"axes": reduce_dim, "keepdims": keepdims})
1657-
res = ctx.make_node(op_type="Cast", inputs=reduce_node.output, attr={"to": onnx_pb.TensorProto.BOOL},
1662+
1663+
zero_node = ctx.make_const(utils.make_name("zero_reduce"), np.array(0, dtype=np.float32))
1664+
res = ctx.make_node(op_type="Greater", inputs=[reduce_node.output[0], zero_node.output[0]],
16581665
name=node.name, outputs=node.output)
1659-
return [cast, reduce_node, res]
1666+
return [cast, reduce_node, zero_node, res]
16601667

16611668

16621669
def zeroslike_op(ctx, node, name, args):

0 commit comments

Comments (0)