Skip to content

Commit c3dc052

Browse files
Made SegmentSum support higher-dimensional data
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 7d04411 commit c3dc052

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

tests/test_backend.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1307,14 +1307,22 @@ def func(x):
13071307
return tf.identity(x_, name=_TFOUTPUT)
13081308
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
13091309

1310-
def test_segment_sum(self):
1310+
def test_segment_sum_data_vector(self):
13111311
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
13121312
data_val = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=np.float32)
13131313
def func(data, segments):
13141314
x_ = tf.math.segment_sum(data, segments)
13151315
return tf.identity(x_, name=_TFOUTPUT)
13161316
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
13171317

1318+
def test_segment_sum_data_tensor(self):
1319+
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
1320+
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
1321+
def func(data, segments):
1322+
x_ = tf.math.segment_sum(data, segments)
1323+
return tf.identity(x_, name=_TFOUTPUT)
1324+
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
1325+
13181326
@check_onnxruntime_incompatibility("Sqrt")
13191327
def test_sqrt(self):
13201328
x_val = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/reduction.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,20 @@ def version_9(cls, ctx, node, **kwargs):
140140
segment_inp = node.input[1]
141141
data_shape = ctx.get_shape(data_inp)
142142
utils.make_sure(data_shape is not None, "Segment ops require input rank to be known")
143+
data_rank = len(data_shape)
143144
data_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(data_inp))
144145
seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
145146
max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
146147
one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
147148
num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]])
148149
onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=data_np_dtype))
149150
one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]], attr={'axis': 0})
150-
mul_node = ctx.make_node("Mul", [data_inp, one_hot_node.output[0]])
151+
one_hot_unsqueeze = one_hot_node
152+
if data_rank > 1:
153+
new_dims = list(range(2, 2 + data_rank - 1))
154+
one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_node.output[0]], attr={'axes': new_dims})
155+
156+
mul_node = ctx.make_node("Mul", [data_inp, one_hot_unsqueeze.output[0]])
151157

152158
shapes = node.output_shapes
153159
dtypes = node.output_dtypes

0 commit comments

Comments
 (0)