Skip to content

Commit 4cfc025

Browse files
Improve perf of SegmentSum (#1463)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent f14d5a7 commit 4cfc025

File tree

7 files changed

+105
-98
lines changed

7 files changed

+105
-98
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ def get_shape(info):
250250
return None
251251
return [d.dim_value if d.HasField('dim_value') else -1 for d in info.type.tensor_type.shape.dim]
252252
for info in model_shapes.graph.value_info:
253+
if info.name == "":
254+
continue
253255
onnx_shape = get_shape(info)
254256
tf2onnx_shape = graph.get_shape(info.name)
255257
if onnx_shape is None:

tests/test_backend.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,7 +1544,7 @@ def func(x, y):
15441544
return tf.identity(x_, name=_TFOUTPUT)
15451545
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
15461546

1547-
@check_opset_min_version(9, "OneHot")
1547+
@check_opset_min_version(11, "ScatterND")
15481548
def test_segment_sum_data_vector(self):
15491549
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
15501550
data_val = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=np.float32)
@@ -1553,7 +1553,7 @@ def func(data, segments):
15531553
return tf.identity(x_, name=_TFOUTPUT)
15541554
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
15551555

1556-
@check_opset_min_version(11, "Pad")
1556+
@check_opset_min_version(11, "ScatterND")
15571557
def test_segment_sum_unknown_rank(self):
15581558
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
15591559
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
@@ -1568,7 +1568,7 @@ def func(data, segments, data_shape, shape_pad):
15681568
self._run_test_case(func, [_OUTPUT],
15691569
{_INPUT: data_val, _INPUT1: segs_val, _INPUT2: data_shape_val, _INPUT3: shape_pad_val})
15701570

1571-
@check_opset_min_version(9, "OneHot")
1571+
@check_opset_min_version(11, "ScatterND")
15721572
def test_segment_ops_data_tensor(self):
15731573
for tf_op in [tf.math.segment_sum, tf.math.segment_prod, tf.math.segment_min, tf.math.segment_max]:
15741574
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
@@ -1578,7 +1578,7 @@ def func(data, segments):
15781578
return tf.identity(x_, name=_TFOUTPUT)
15791579
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
15801580

1581-
@check_opset_min_version(11, "Pad")
1581+
@check_opset_min_version(11, "ScatterND")
15821582
@skip_tflite("unknown rank")
15831583
def test_segment_mean_unknown_rank(self):
15841584
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
@@ -1594,7 +1594,7 @@ def func(data, segments, data_shape, shape_pad):
15941594
self._run_test_case(func, [_OUTPUT],
15951595
{_INPUT: data_val, _INPUT1: segs_val, _INPUT2: data_shape_val, _INPUT3: shape_pad_val})
15961596

1597-
@check_opset_min_version(9, "OneHot")
1597+
@check_opset_min_version(11, "ScatterND")
15981598
def test_sparse_segment_sum(self):
15991599
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
16001600
indices_val = np.array([2, 0, 1, 3, 5, 4, 3, 5, 5], dtype=np.int32)
@@ -1604,7 +1604,7 @@ def func(data, indices, segments):
16041604
return tf.identity(x_, name=_TFOUTPUT)
16051605
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: indices_val, _INPUT2: segs_val})
16061606

1607-
@check_opset_min_version(9, "OneHot")
1607+
@check_opset_min_version(11, "ScatterND")
16081608
def test_sparse_segment_mean(self):
16091609
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
16101610
indices_val = np.array([2, 0, 1, 3, 5, 4, 3, 5, 5], dtype=np.int32)
@@ -1614,7 +1614,7 @@ def func(data, indices, segments):
16141614
return tf.identity(x_, name=_TFOUTPUT)
16151615
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: indices_val, _INPUT2: segs_val})
16161616

1617-
@check_opset_min_version(9, "OneHot")
1617+
@check_opset_min_version(11, "ScatterND")
16181618
def test_sparse_segment_sqrtn(self):
16191619
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
16201620
indices_val = np.array([2, 0, 1, 3, 5, 4, 3, 5, 5], dtype=np.int32)
@@ -1624,7 +1624,7 @@ def func(data, indices, segments):
16241624
return tf.identity(x_, name=_TFOUTPUT)
16251625
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: indices_val, _INPUT2: segs_val})
16261626

1627-
@check_opset_min_version(9, "OneHot")
1627+
@check_opset_min_version(11, "ScatterND")
16281628
def test_sparse_segment_ops_with_num_segments(self):
16291629
for tf_op in [tf.sparse.segment_sum, tf.sparse.segment_mean, tf.sparse.segment_sqrt_n]:
16301630
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
@@ -1635,7 +1635,7 @@ def func(data, indices, segments):
16351635
return tf.identity(x_, name=_TFOUTPUT)
16361636
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: indices_val, _INPUT2: segs_val})
16371637

1638-
@check_opset_min_version(9, "OneHot")
1638+
@check_opset_min_version(11, "ScatterND")
16391639
@check_tf_min_version("2.3", "needs tf 2.3")
16401640
def test_unsorted_segment_ops(self):
16411641
tf_ops = [
@@ -1654,7 +1654,7 @@ def func(data, segments):
16541654
return tf.identity(x_, name=_TFOUTPUT)
16551655
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
16561656

1657-
@check_opset_min_version(9, "OneHot")
1657+
@check_opset_min_version(11, "ScatterND")
16581658
@check_tf_min_version("2.3", "num_segments can be int64 in tf 2.3")
16591659
def test_segment_op_types(self):
16601660
data_dtypes = [np.int32, np.float32]

tests/test_cond.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def case_graph():
233233

234234
@check_tf_min_version("1.8", "shape inference for Reshape op screws up")
235235
@check_opset_min_version(9, "ConstantOfShape")
236+
@allow_missing_shapes("ONNX shape inference still determines if/else shape for unknown reason")
236237
def test_cond_with_different_output_shape(self):
237238
input_shape = (10, 5, 20)
238239
def func(inputs, shape):

tf2onnx/graph_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None, return_node=Fa
8282
return node
8383
return node.output[0]
8484

85-
def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
85+
def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None, op_name_scope=None):
8686
"""
8787
ReduceSum changes its schema at opset 13: it treats some axes as dynamic input
8888
kwargs: key could be ["data", "axes", "keepdims", "noop_with_empty_axes", "outputs"].
@@ -115,7 +115,8 @@ def make_reduce_sum(self, kwargs, name=None, shapes=None, dtypes=None):
115115
attr = new_attr
116116

117117
return self.graph.make_node(op_type="ReduceSum", inputs=inputs, attr=attr, name=name,
118-
outputs=outputs, shapes=shapes, dtypes=dtypes).output[0]
118+
outputs=outputs, shapes=shapes, dtypes=dtypes,
119+
op_name_scope=op_name_scope).output[0]
119120

120121
def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None, return_node=False, op_name_scope=None):
121122
"""

tf2onnx/onnx_opset/controlflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def version_9(cls, ctx, node, **kwargs):
217217
node.output[0], name=utils.make_name("where_op_added"))
218218
ctx.copy_shape(node.output[0], transpose_node.output[0])
219219
ctx.copy_dtype(node.output[0], transpose_node.output[0])
220+
ctx.update_node_shape_dtype(node, override=True)
220221

221222

222223
@tf_op(["StatelessIf"])

0 commit comments

Comments
 (0)