Skip to content

Commit b2bfaea

Browse files
committed
pass only int64 as shape to tile, expand dims op
1 parent 9875243 commit b2bfaea

File tree

4 files changed

+41
-25
lines changed

4 files changed

+41
-25
lines changed

tests/test_backend.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,15 @@ def test_trig_ops(self):
173173

174174
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not supported correctly in caffe2")
175175
def test_multinomial(self):
176-
x_val = make_xval([3, 4])
176+
shape = [2, 10]
177+
x_val = np.ones(np.prod(shape)).astype("float32").reshape(shape)
177178
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
178179
op = tf.multinomial(x, 2, output_dtype=tf.int64)
179180
output = tf.identity(op, name=_TFOUTPUT)
180181
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
181-
self.assertAllClose(expected, actual, rtol=1e-06)
182+
# since returned indexes are random we can only check type and shape
183+
self.assertEqual(expected.dtype, actual.dtype)
184+
self.assertAllClose(expected.shape, actual.shape)
182185

183186
def test_maxppol(self):
184187
x_val = make_xval((1, 4, 4, 1))

tests/test_graph.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from tf2onnx.graph_matcher import *
1515

1616

17+
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
18+
19+
1720
def onnx_to_graphviz(g):
1821
g2 = gv.Digraph()
1922
for node in g.get_nodes():
@@ -266,7 +269,7 @@ def print_handler(ctx, node, name, args):
266269
# becomes:
267270
# T output = Identity(T Input)
268271
node.type = "Identity"
269-
node.domain = "tf"
272+
node.domain = _TENSORFLOW_DOMAIN
270273
del node.input[1:]
271274
return node
272275

@@ -276,7 +279,7 @@ def print_handler(ctx, node, name, args):
276279
_ = tf.identity(x_, name="output")
277280
g = process_tf_graph(sess.graph,
278281
custom_op_handlers={"Print": print_handler},
279-
extra_opset=helper.make_opsetid("tf", 1))
282+
extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
280283
self.assertEqual(
281284
'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
282285
onnx_to_graphviz(g))

tf2onnx/convert.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from onnx import helper
1616

1717

18+
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
19+
20+
1821
def get_args():
1922
"""Parse commandline."""
2023
parser = argparse.ArgumentParser()
@@ -45,7 +48,7 @@ def get_args():
4548

4649

4750
def default_custom_op_handler(ctx, node, name, args):
48-
node.domain = "tf"
51+
node.domain = _TENSORFLOW_DOMAIN
4952
return node
5053

5154

@@ -61,7 +64,7 @@ def main():
6164
if args.custom_ops:
6265
# default custom ops for tensorflow-onnx are in the "tf" namespace
6366
custom_ops = {op: default_custom_op_handler for op in args.custom_ops.split(",")}
64-
extra_opset = [helper.make_opsetid("tf", 1)]
67+
extra_opset = [helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)]
6568
else:
6669
custom_ops = {}
6770
extra_opset = None

tf2onnx/tfonnx.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,25 @@ def tensorflow_to_onnx(graph):
105105
return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes
106106

107107

108-
# pylint: disable=W0613,C0111,W0612
108+
def _convert_shapenode_to_int64(ctx, node, input_number):
109+
shape_node = node.inputs[1]
110+
name = node.input[1]
111+
if shape_node.is_const():
112+
# if it is a const, change the const to be int64
113+
shape = shape_node.get_tensor_value()
114+
shape = np.array(list(shape), dtype=np.int64)
115+
onnx_tensor = numpy_helper.from_array(shape, name)
116+
ctx._initializers[name] = onnx_tensor
117+
shape_node.set_attr("value", onnx_tensor)
118+
return [node]
119+
else:
120+
op_name = utils.make_name(node.name)
121+
cast_op = ctx.insert_new_node_on_input(node, "Cast", name, name=op_name)
122+
cast_op.set_attr("to", onnx_pb.TensorProto.INT64)
123+
ctx.copy_shape(name, op_name + ":0")
124+
return [cast_op, node]
109125

126+
# pylint: disable=W0613,C0111,W0612
110127

111128
def no_op(ctx, node, name, args):
112129
"""Skip node."""
@@ -255,23 +272,8 @@ def reshape_op(ctx, node, name, args):
255272

256273

257274
def reshape_op5(ctx, node, name, args):
258-
shape_node = node.inputs[1]
259275
# onnx wants reshape.input[1] to have the value be int64 which is not the case for tensorflow.
260-
name = node.input[1]
261-
if shape_node.is_const():
262-
# if it is a const, change the const to be int64
263-
shape = shape_node.get_tensor_value()
264-
shape = np.array(list(shape), dtype=np.int64)
265-
onnx_tensor = numpy_helper.from_array(shape, name)
266-
ctx._initializers[name] = onnx_tensor
267-
shape_node.set_attr("value", onnx_tensor)
268-
return node
269-
else:
270-
op_name = utils.make_name(node.name)
271-
cast_op = ctx.insert_new_node_on_input(node, "Cast", name, name=op_name)
272-
cast_op.set_attr("to", onnx_pb.TensorProto.INT64)
273-
ctx.copy_shape(name, op_name + ":0")
274-
return [cast_op, node]
276+
return _convert_shapenode_to_int64(ctx, node, 1)
275277

276278

277279
NCHW_TO_NHWC = [0, 2, 3, 1]
@@ -690,7 +692,7 @@ def expanddims_op(ctx, node, name, args):
690692
def expanddims_op7(ctx, node, name, args):
691693
shape = ctx.get_shape(node.output[0])
692694
shape_name = utils.make_name(node.name)
693-
shape_node = ctx.make_const(shape_name, "Const", np.array(shape))
695+
shape_node = ctx.make_const(shape_name, "Const", np.array(shape, dtype=np.int64))
694696
node.type = "Reshape"
695697
node.input[1] = shape_name
696698
return node
@@ -786,6 +788,11 @@ def topk_op(ctx, node, name, args):
786788
return node
787789

788790

791+
def tile_op7(ctx, node, name, args):
792+
# onnx wants shape input to be int64
793+
return _convert_shapenode_to_int64(ctx, node, 1)
794+
795+
789796
# pylint: enable=W0613,C0111,W0612
790797

791798
# map tensorflow ops to onnx ops. The format below is
@@ -881,7 +888,7 @@ def topk_op(ctx, node, name, args):
881888
}
882889

883890
_OPSET_7 = {
884-
"Tile": (direct_op, []),
891+
"Tile": (tile_op7, []),
885892
"ResizeNearestNeighbor": (upsample_op, []),
886893
"BiasAdd": (biasadd_op7, []),
887894
"BiasAddV1": (biasadd_op7, []),

0 commit comments

Comments
 (0)