Skip to content

Commit a1e0e09

Browse files
authored
Merge pull request #1161 from onnx/tom/SegmentSum
Added conversions for Segment ops
2 parents aec5854 + 6ab8cfb commit a1e0e09

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

tests/test_backend.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,6 +1307,25 @@ def func(x):
13071307
return tf.identity(x_, name=_TFOUTPUT)
13081308
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
13091309

1310+
@check_opset_min_version(9, "OneHot")
1311+
def test_segment_sum_data_vector(self):
1312+
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
1313+
data_val = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=np.float32)
1314+
def func(data, segments):
1315+
x_ = tf.math.segment_sum(data, segments)
1316+
return tf.identity(x_, name=_TFOUTPUT)
1317+
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
1318+
1319+
@check_opset_min_version(9, "OneHot")
1320+
def test_segment_ops_data_tensor(self):
1321+
for tf_op in [tf.math.segment_sum, tf.math.segment_prod, tf.math.segment_min, tf.math.segment_max]:
1322+
segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
1323+
data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
1324+
def func(data, segments):
1325+
x_ = tf_op(data, segments)
1326+
return tf.identity(x_, name=_TFOUTPUT)
1327+
self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
1328+
13101329
@check_onnxruntime_incompatibility("Sqrt")
13111330
def test_sqrt(self):
13121331
x_val = np.array([4.0, 16.0, 4.0, 1.6], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/reduction.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,61 @@ class AddN():
130130
@classmethod
131131
def version_6(cls, ctx, node, **kwargs):
132132
node.type = "Sum"
133+
134+
135+
@tf_op(["SegmentSum", "SegmentProd", "SegmentMax", "SegmentMin"])
136+
class SegmentSum():
137+
@classmethod
138+
def version_9(cls, ctx, node, **kwargs):
139+
data_inp = node.input[0]
140+
segment_inp = node.input[1]
141+
data_shape = ctx.get_shape(data_inp)
142+
utils.make_sure(data_shape is not None, "Segment ops require input rank to be known")
143+
data_rank = len(data_shape)
144+
data_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(data_inp))
145+
seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
146+
data_is_float = np.dtype(data_np_dtype).kind == 'f'
147+
data_is_int = np.dtype(data_np_dtype).kind == 'i'
148+
utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")
149+
150+
if node.type == "SegmentSum":
151+
onnx_op = "ReduceSum"
152+
identity_value = np.array(0, dtype=data_np_dtype)
153+
elif node.type == "SegmentProd":
154+
onnx_op = "ReduceProd"
155+
identity_value = np.array(1, dtype=data_np_dtype)
156+
elif node.type == "SegmentMax":
157+
onnx_op = "ReduceMax"
158+
if data_is_float:
159+
identity_value = np.array('-inf', dtype=data_np_dtype)
160+
else:
161+
identity_value = np.iinfo(data_np_dtype).min
162+
elif node.type == "SegmentMin":
163+
onnx_op = "ReduceMin"
164+
if data_is_float:
165+
identity_value = np.array('inf', dtype=data_np_dtype)
166+
else:
167+
identity_value = np.iinfo(data_np_dtype).max
168+
169+
max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
170+
one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
171+
identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
172+
num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]])
173+
# ORT doesn't support bool for OneHot so we use float32 and cast to bool
174+
onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
175+
one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]],
176+
attr={'axis': 0})
177+
one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]], attr={"to": onnx_pb.TensorProto.BOOL})
178+
one_hot_unsqueeze = one_hot_bool
179+
180+
if data_rank > 1:
181+
new_dims = list(range(2, 2 + data_rank - 1))
182+
one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_bool.output[0]], attr={'axes': new_dims})
183+
184+
mul_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])
185+
186+
shapes = node.output_shapes
187+
dtypes = node.output_dtypes
188+
ctx.remove_node(node.name)
189+
ctx.make_node(onnx_op, [mul_node.output[0]], attr={'axes': [1], 'keepdims': 0},
190+
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)

0 commit comments

Comments
 (0)