Skip to content

Commit 7d04411

Browse files
Added support for SegmentSum
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 1bc3079 commit 7d04411

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,14 @@ def func(x):
13071307
return tf.identity(x_, name=_TFOUTPUT)
13081308
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
13091309

1310+
def test_segment_sum(self):
1311+
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
1312+
data_val = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=np.float32)
1313+
def func(data, segments):
1314+
x_ = tf.math.segment_sum(data, segments)
1315+
return tf.identity(x_, name=_TFOUTPUT)
1316+
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
1317+
13101318
@check_onnxruntime_incompatibility("Sqrt")
13111319
def test_sqrt(self):
13121320
x_val = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/reduction.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,27 @@ class AddN():
130130
@classmethod
131131
def version_6(cls, ctx, node, **kwargs):
132132
node.type = "Sum"
133+
134+
135+
@tf_op("SegmentSum")
136+
class SegmentSum():
137+
@classmethod
138+
def version_9(cls, ctx, node, **kwargs):
139+
data_inp = node.input[0]
140+
segment_inp = node.input[1]
141+
data_shape = ctx.get_shape(data_inp)
142+
utils.make_sure(data_shape is not None, "Segment ops require input rank to be known")
143+
data_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(data_inp))
144+
seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
145+
max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
146+
one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
147+
num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]])
148+
onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=data_np_dtype))
149+
one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]], attr={'axis': 0})
150+
mul_node = ctx.make_node("Mul", [data_inp, one_hot_node.output[0]])
151+
152+
shapes = node.output_shapes
153+
dtypes = node.output_dtypes
154+
ctx.remove_node(node.name)
155+
sum_node = ctx.make_node("ReduceSum", [mul_node.output[0]], attr={'axes': [1], 'keepdims': 0},
156+
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)

0 commit comments

Comments
 (0)