Commit 6ab8cfb

Added support for SegmentProd, SegmentMax, SegmentMin
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c3dc052 commit 6ab8cfb

2 files changed: +46 -15 lines changed

2 files changed

+46
-15
lines changed
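
For context, the TensorFlow segment ops covered by this commit partition the rows of `data` by `segment_ids` and reduce each partition with sum, prod, max, or min. A minimal sketch of that behavior, not part of the commit, assuming TF 2.x eager execution and using the same values as the existing vector test:

    import tensorflow as tf

    data = tf.constant([5., 1., 7., 2., 3., 4., 1., 3.])
    segment_ids = tf.constant([0, 0, 0, 1, 2, 2, 3, 3])

    # Output row i reduces the input rows whose segment id equals i.
    print(tf.math.segment_sum(data, segment_ids))   # values: 13, 2, 7, 4
    print(tf.math.segment_prod(data, segment_ids))  # values: 35, 2, 12, 3
    print(tf.math.segment_max(data, segment_ids))   # values: 7, 2, 4, 3
    print(tf.math.segment_min(data, segment_ids))   # values: 1, 2, 3, 1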

tests/test_backend.py

Lines changed: 10 additions & 7 deletions

@@ -1307,6 +1307,7 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @check_opset_min_version(9, "OneHot")
     def test_segment_sum_data_vector(self):
         segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
         data_val = np.array([5, 1, 7, 2, 3, 4, 1, 3], dtype=np.float32)
@@ -1315,13 +1316,15 @@ def func(data, segments):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
 
-    def test_segment_sum_data_tensor(self):
-        segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
-        data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
-        def func(data, segments):
-            x_ = tf.math.segment_sum(data, segments)
-            return tf.identity(x_, name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
+    @check_opset_min_version(9, "OneHot")
+    def test_segment_ops_data_tensor(self):
+        for tf_op in [tf.math.segment_sum, tf.math.segment_prod, tf.math.segment_min, tf.math.segment_max]:
+            segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
+            data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])
+            def func(data, segments):
+                x_ = tf_op(data, segments)
+                return tf.identity(x_, name=_TFOUTPUT)
+            self._run_test_case(func, [_OUTPUT], {_INPUT: data_val, _INPUT1: segs_val})
 
     @check_onnxruntime_incompatibility("Sqrt")
     def test_sqrt(self):
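
The reworked tensor test runs all four ops on rank-3 data. These ops reduce along axis 0 only, so with 8 input rows and segment ids 0 through 3 the output keeps the trailing dimensions. A small sketch of the expected shape, assuming TF 2.x eager execution and the same values as the test:

    import numpy as np
    import tensorflow as tf

    segs_val = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int32)
    data_val = np.arange(8 * 2 * 3, dtype=np.float32).reshape([8, 2, 3])

    out = tf.math.segment_max(data_val, segs_val)
    print(out.shape)  # (4, 2, 3): one slice per segment id, trailing dims preserved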

tf2onnx/onnx_opset/reduction.py

Lines changed: 36 additions & 8 deletions

@@ -132,7 +132,7 @@ def version_6(cls, ctx, node, **kwargs):
         node.type = "Sum"
 
 
-@tf_op("SegmentSum")
+@tf_op(["SegmentSum", "SegmentProd", "SegmentMax", "SegmentMin"])
 class SegmentSum():
     @classmethod
     def version_9(cls, ctx, node, **kwargs):
@@ -143,20 +143,48 @@ def version_9(cls, ctx, node, **kwargs):
         data_rank = len(data_shape)
         data_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(data_inp))
         seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
+        data_is_float = np.dtype(data_np_dtype).kind == 'f'
+        data_is_int = np.dtype(data_np_dtype).kind == 'i'
+        utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")
+
+        if node.type == "SegmentSum":
+            onnx_op = "ReduceSum"
+            identity_value = np.array(0, dtype=data_np_dtype)
+        elif node.type == "SegmentProd":
+            onnx_op = "ReduceProd"
+            identity_value = np.array(1, dtype=data_np_dtype)
+        elif node.type == "SegmentMax":
+            onnx_op = "ReduceMax"
+            if data_is_float:
+                identity_value = np.array('-inf', dtype=data_np_dtype)
+            else:
+                identity_value = np.iinfo(data_np_dtype).min
+        elif node.type == "SegmentMin":
+            onnx_op = "ReduceMin"
+            if data_is_float:
+                identity_value = np.array('inf', dtype=data_np_dtype)
+            else:
+                identity_value = np.iinfo(data_np_dtype).max
+
         max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
         one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
+        identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
         num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]])
-        onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=data_np_dtype))
-        one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]], attr={'axis': 0})
-        one_hot_unsqueeze = one_hot_node
+        # ORT doesn't support bool for OneHot so we use float32 and cast to bool
+        onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
+        one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]],
+                                     attr={'axis': 0})
+        one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]], attr={"to": onnx_pb.TensorProto.BOOL})
+        one_hot_unsqueeze = one_hot_bool
+
         if data_rank > 1:
             new_dims = list(range(2, 2 + data_rank - 1))
-            one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_node.output[0]], attr={'axes': new_dims})
+            one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_bool.output[0]], attr={'axes': new_dims})
 
-        mul_node = ctx.make_node("Mul", [data_inp, one_hot_unsqueeze.output[0]])
+        mul_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])
 
         shapes = node.output_shapes
         dtypes = node.output_dtypes
         ctx.remove_node(node.name)
-        sum_node = ctx.make_node("ReduceSum", [mul_node.output[0]], attr={'axes': [1], 'keepdims': 0},
-                                 name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
+        ctx.make_node(onnx_op, [mul_node.output[0]], attr={'axes': [1], 'keepdims': 0},
+                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
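
With this change the converter emits OneHot, a Cast to bool, an optional Unsqueeze, a Where, and then the op-specific Reduce over axis 1. Selecting identity_value where the one-hot mask is False is what keeps ReduceProd, ReduceMax, and ReduceMin correct; the previous multiply-by-zero mask only works for a sum. A rough numpy emulation of that lowering for 1-D data and SegmentMax, my own illustration rather than code from this commit:

    import numpy as np

    data = np.array([5., 1., 7., 2., 3., 4., 1., 3.], dtype=np.float32)
    segments = np.array([0, 0, 0, 1, 2, 2, 3, 3], dtype=np.int64)

    num_segments = segments.max() + 1                              # ReduceMax + Add
    # OneHot with axis=0: mask[i, j] is True where segments[j] == i.
    mask = segments[None, :] == np.arange(num_segments)[:, None]   # OneHot + Cast to bool

    identity_value = np.float32('-inf')                            # identity for SegmentMax
    selected = np.where(mask, data, identity_value)                # Where(mask, data, identity)
    print(selected.max(axis=1))                                    # ReduceMax over axis 1 -> [7. 2. 4. 3.]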
