Skip to content

Commit e483994

Browse files
Merge pull request #1233 from onnx/tom/opset13updates
Update split op for opset13 and add test cases
2 parents 1e71c92 + 2d7505c commit e483994

File tree

4 files changed

+81
-2
lines changed

4 files changed

+81
-2
lines changed

tests/test_backend.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,15 @@ def func(x):
13661366
return tf.identity(x_, name=_TFOUTPUT)
13671367
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
13681368

1369+
@check_opset_min_version(13, "Split")
1370+
def test_split_nonconst(self):
1371+
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape((5, 30))
1372+
y_val = np.array([4, 15, 11], np.int32)
1373+
def func(x, y):
1374+
x_, _, _ = tf.split(x, y, 1)
1375+
return tf.identity(x_, name=_TFOUTPUT)
1376+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1377+
13691378
def test_split_with_more_outputs(self):
13701379
x_val = np.linspace(1.0, 5 * 30.0, 5 * 30).astype(np.float32).reshape((5, 30))
13711380
def func(x):
@@ -1387,6 +1396,24 @@ def func(x):
13871396
return tf.identity(x_, name=_TFOUTPUT)
13881397
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
13891398

1399+
@check_opset_min_version(13, "ReduceSum")
1400+
def test_reducesum_nonconst_axis(self):
1401+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 1, 2))
1402+
y_val = np.array([1, 2], dtype=np.int32)
1403+
def func(x, y):
1404+
x_ = tf.reduce_sum(x, axis=y)
1405+
return tf.identity(x_, name=_TFOUTPUT)
1406+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1407+
1408+
@check_opset_min_version(13, "ReduceSum")
1409+
def test_reducesum_empty_axis(self):
1410+
x_val = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).reshape((2, 1, 2))
1411+
y_val = np.array([], dtype=np.int32)
1412+
def func(x, y):
1413+
x_ = tf.reduce_sum(x, axis=y)
1414+
return tf.identity(x_, name=_TFOUTPUT)
1415+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
1416+
13901417
@check_opset_min_version(9, "OneHot")
13911418
def test_segment_sum_data_vector(self):
13921419
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
@@ -2866,6 +2893,16 @@ def func(x):
28662893
return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
28672894
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
28682895

2896+
@check_opset_min_version(11, "ReduceSum")
2897+
@check_tf_min_version("1.15")
2898+
def test_reduce_any_empty_axis(self):
2899+
input_val = np.random.randint(0, 2, (10, 20)).astype(np.bool)
2900+
def func(x):
2901+
res = tf.reduce_any(input_tensor=x, keepdims=False)
2902+
res1 = tf.reduce_any(input_tensor=x, axis=[], keepdims=False)
2903+
return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
2904+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
2905+
28692906
@check_opset_min_version(7, "fill")
28702907
def test_zeros_like(self):
28712908
input_x = np.random.random_sample([10, 20]).astype(np.float32)
@@ -3289,6 +3326,20 @@ def func(x):
32893326
return tf.identity(y, name=_TFOUTPUT)
32903327
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32913328

3329+
def test_softmax(self):
3330+
x_val = np.arange(0, 24, dtype=np.float32).reshape([3, 1, 8])
3331+
def func(x):
3332+
y = tf.nn.softmax(x)
3333+
return tf.identity(y, name=_TFOUTPUT)
3334+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3335+
3336+
def test_log_softmax(self):
3337+
x_val = np.arange(0, 24, dtype=np.float32).reshape([3, 1, 8])
3338+
def func(x):
3339+
y = tf.nn.log_softmax(x)
3340+
return tf.identity(y, name=_TFOUTPUT)
3341+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3342+
32923343
# test for gemm pattern0: alpha*A*B + beta*C
32933344
def test_gemm_pattern0(self):
32943345
max_number = 10

tf2onnx/onnx_opset/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ def version_1(cls, ctx, node, **kwargs):
201201
def version_11(cls, ctx, node, **kwargs):
202202
cls.version_1(ctx, node, **kwargs)
203203

204+
@classmethod
205+
def version_13(cls, ctx, node, **kwargs):
206+
# Default axis is now -1.
207+
pass
208+
204209

205210
@tf_op("Square")
206211
class Square:

tf2onnx/onnx_opset/rnn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,9 @@ def make_sigmoid(i, w, b):
9494
raise RuntimeError("shape of W of LSTMBlockCell {} should be times of 4".format(node.name))
9595
merged_output_node = ctx.make_node("Add", [xh_w_node.output[0], b])
9696
w_last_dim = int(w_shape[1] / 4)
97-
split = [w_last_dim] * 4
9897
split_output_node = ctx.make_node(
9998
"Split", [merged_output_node.output[0]],
100-
attr={"axis": 1, "split": split},
99+
attr={"axis": 1},
101100
output_count=4
102101
)
103102
i, ci, f, o = split_output_node.output

tf2onnx/onnx_opset/tensor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,10 @@ def version_11(cls, ctx, node, **kwargs):
625625
# no change
626626
cls.version_1(ctx, node, **kwargs)
627627

628+
@classmethod
629+
def version_13(cls, ctx, node, **kwargs):
630+
# Default axis is not -1 but doesn't matter since we always set it.
631+
cls.version_1(ctx, node, **kwargs)
628632

629633
@tf_op("SplitV")
630634
class SplitV:
@@ -652,6 +656,26 @@ def version_1(cls, ctx, node, **kwargs):
652656
def version_2(cls, ctx, node, **kwargs):
653657
cls.version_1(ctx, node, **kwargs)
654658

659+
@classmethod
660+
def version_13(cls, ctx, node, **kwargs):
661+
# Split now supports dynamic split lengths
662+
if node.inputs[1].is_const():
663+
# Call version 1 to deal with -1 cases
664+
cls.version_1(ctx, node, **kwargs)
665+
# Convert attr to input
666+
split_val = node.get_attr_value("split")
667+
split_const = ctx.make_const(utils.make_name("split"), np.array(split_val, np.int64))
668+
ctx.replace_inputs(node, [node.input[0], split_const.output[0]])
669+
del node.attr["split"]
670+
else:
671+
# Technically incorrect if any of the splits are -1
672+
node.type = "Split"
673+
split_dims = node.inputs[2].get_tensor_value()
674+
ctx.remove_input(node, node.input[2], 2)
675+
node.set_attr("axis", split_dims)
676+
if ctx.get_dtype(node.input[1]) != TensorProto.INT64:
677+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
678+
655679

656680
@tf_op("ExpandDims")
657681
class ExpandDims:

0 commit comments

Comments (0)