Skip to content

Commit bc7f40e

Browse files
authored
Merge pull request #50 from onnx/gs/onnx-1.2
work around missing onnx broadcast support for min/max
2 parents 47ef89a + f0a2462 commit bc7f40e

File tree

3 files changed

+70
-18
lines changed

3 files changed

+70
-18
lines changed

tests/test_backend.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,17 @@ def test_trig_ops(self):
173173

174174
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not supported correctly in caffe2")
175175
def test_multinomial(self):
176+
x_val = np.array([[10., 10.]], dtype=np.float32)
177+
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
178+
op = tf.multinomial(tf.log(x), 5, output_dtype=tf.int64)
179+
output = tf.identity(op, name=_TFOUTPUT)
180+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
181+
# since returned indexes are random we can only check type and shape
182+
self.assertEqual(expected.dtype, actual.dtype)
183+
self.assertEqual(expected.shape, actual.shape)
184+
185+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not supported correctly in caffe2")
186+
def test_multinomial1(self):
176187
shape = [2, 10]
177188
x_val = np.ones(np.prod(shape)).astype("float32").reshape(shape)
178189
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
@@ -181,7 +192,7 @@ def test_multinomial(self):
181192
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
182193
# since returned indexes are random we can only check type and shape
183194
self.assertEqual(expected.dtype, actual.dtype)
184-
self.assertAllClose(expected.shape, actual.shape)
195+
self.assertEqual(expected.shape, actual.shape)
185196

186197
def test_maxppol(self):
187198
x_val = make_xval((1, 4, 4, 1))
@@ -462,13 +473,25 @@ def test_square(self):
462473
def test_min(self):
463474
x_val1 = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.float32).reshape((2, 2))
464475
x_val2 = np.array([4.0, 4.0, 4.0, 4.0], dtype=np.float32).reshape((2, 2))
465-
x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
466-
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
476+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
477+
x2 = tf.placeholder(tf.float32, x_val2.shape, name=_TFINPUT1)
467478
mi = tf.minimum(x1, x2)
468479
output = tf.identity(mi, name=_TFOUTPUT)
469480
actual, expected = self._run(output, {x1: x_val1, x2: x_val2}, {_INPUT: x_val1, _INPUT1: x_val2, })
470481
self.assertAllClose(expected, actual)
471482

483+
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "issue with broadcastnig scalar")
484+
def test_min_broadcast(self):
485+
# tests if the broadcast for min/max is working
486+
x_val1 = np.array([2.0, 16.0, 5.0, 1.6], dtype=np.float32).reshape((2, 2))
487+
x_val2 = np.array([4.0], dtype=np.float32)
488+
x1 = tf.placeholder(tf.float32, x_val1.shape, name=_TFINPUT)
489+
x2 = tf.constant(x_val2, dtype=tf.float32, name='x2')
490+
mi = tf.minimum(x1, x2)
491+
output = tf.identity(mi, name=_TFOUTPUT)
492+
actual, expected = self._run(output, {x1: x_val1}, {_INPUT: x_val1})
493+
self.assertAllClose(expected, actual)
494+
472495
def test_logicaland(self):
473496
x_val1 = np.array([1, 0, 1, 1], dtype=np.bool).reshape((2, 2))
474497
x_val2 = np.array([0, 1, 1, 1], dtype=np.bool).reshape((2, 2))
@@ -712,7 +735,6 @@ def test_pad(self):
712735

713736
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not supported correctly in caffe2")
714737
def test_randomuniform(self):
715-
# not supported by onnxmsrt or caffe2
716738
shape = tf.constant([2, 3], name="shape")
717739
x_ = tf.random_uniform(shape, name="rand", dtype=tf.float32)
718740
x_ = tf.identity(x_, name="output1")
@@ -722,6 +744,17 @@ def test_randomuniform(self):
722744
# since results are random, compare the shapes only
723745
self.assertAllClose(expected.shape, actual.shape)
724746

747+
@unittest.skip
748+
def test_randomuniform_int(self):
749+
shape = tf.constant([2, 3], name="shape")
750+
x_ = tf.random_uniform(shape, name="rand", dtype=tf.int32, maxval=10)
751+
x_ = tf.identity(x_, name="output1")
752+
x_ = tf.identity(x_, name="output2")
753+
output = tf.identity(x_, name=_TFOUTPUT)
754+
actual, expected = self._run(output, {}, {})
755+
# since results are random, compare the shapes only
756+
self.assertAllClose(expected.shape, actual.shape)
757+
725758
@unittest.skip
726759
def test_argminmax(self):
727760
# TODO: fails on onnxmsrt caffe2

tests/unity.yaml

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
3DBall:
2-
# caffe2: needs RandomNormal
3-
# onnxmsrtnext: fails on missing onnx min/max broadcast
4-
disabled: true
2+
check_only_shape: true
53
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/3DBall/TFModels/3DBall.bytes
64
model: 3DBall.bytes
75
input_get: get_random
@@ -57,9 +55,7 @@ Basic:
5755
- value_estimate:0
5856

5957
Bouncer:
60-
# caffe2: needs RandomNormal
61-
# onnxmsrtnext: fails on missing onnx min/max broadcast
62-
disabled: true
58+
check_only_shape: true
6359
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Bouncer/TFModels/Bouncer.bytes
6460
model: Bouncer.bytes
6561
input_get: get_random
@@ -70,9 +66,7 @@ Bouncer:
7066
- value_estimate:0
7167

7268
crawler:
73-
# caffe2: needs RandomNormal
74-
# onnxmsrtnext: fails on missing onnx min/max broadcast
75-
disabled: true
69+
check_only_shape: true
7670
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Crawler/TFModels/crawler.bytes
7771
model: crawler.bytes
7872
input_get: get_random
@@ -134,9 +128,7 @@ PushBlock:
134128
- action:0
135129

136130
Reacher:
137-
# caffe2: needs RandomNormal
138-
# onnxmsrtnext: fails on missing onnx min/max broadcast
139-
disabled: true
131+
check_only_shape: true
140132
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Reacher/TFModels/Reacher.bytes
141133
model: Reacher.bytes
142134
input_get: get_random

tf2onnx/tfonnx.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def tensorflow_to_onnx(graph):
5757
shape = out.get_shape().as_list()
5858
except Exception as ex:
5959
shape = []
60+
dtypes[out.name] = utils.map_tf_dtype(out.dtype)
6061
output_shapes[out.name] = shape
6162

6263
# minimal conversion of attributes
@@ -805,6 +806,32 @@ def spacetodepth_op(ctx, node, name, args):
805806
return nodes
806807

807808

809+
def minmax_op(ctx, node, name, args):
810+
# tensorflow minimum/maximum support broadcast. Onnx <= opset 7 does not.
811+
# inject a add(0) as 'broadcast' operator if needed.
812+
shapeo = ctx.get_shape(node.output[0])
813+
needs_broadcast_op = []
814+
for i, name in enumerate(node.input):
815+
if ctx.get_shape(name) != shapeo:
816+
needs_broadcast_op.append(i)
817+
if needs_broadcast_op:
818+
new_nodes = []
819+
for i in needs_broadcast_op:
820+
input_node = node.inputs[i]
821+
dtype = ctx.dtypes[node.input[i]]
822+
zero_name = utils.make_name(input_node.name)
823+
zero_node = ctx.make_const(zero_name, "Const", np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
824+
op_name = utils.make_name(input_node.name)
825+
output_name = op_name + ":0"
826+
add_node = Node(helper.make_node("Add", [input_node.output[0], zero_name],
827+
[output_name], name=op_name), ctx)
828+
node.input[i] = output_name
829+
new_nodes.append(add_node)
830+
new_nodes.append(node)
831+
return new_nodes
832+
return node
833+
834+
808835
# pylint: enable=W0613,C0111,W0612
809836

810837
# map tensorflow ops to onnx ops. The format below is
@@ -847,12 +874,12 @@ def spacetodepth_op(ctx, node, name, args):
847874
"LogicalOr": (broadcast_op, ["Or"]),
848875
"Max": (reduce_op, ["ReduceMax"]),
849876
"MatMul": (direct_op, ["MatMul"]),
850-
"Maximum": (direct_op, ["Max"]),
877+
"Maximum": (minmax_op, ["Max"]),
851878
"MaxPool": (pool_op, ["MaxPool"]),
852879
"MaxPoolV2": (pool_op, ["MaxPool"]),
853880
"Mean": (reduce_op, ["ReduceMean"]),
854881
"Min": (reduce_op, ["ReduceMin"]),
855-
"Minimum": (direct_op, ["Min"]),
882+
"Minimum": (minmax_op, ["Min"]),
856883
"Mul": (broadcast_op, []),
857884
"Neg": (direct_op, []),
858885
"NoOp": (no_op, []),

0 commit comments

Comments
 (0)