Skip to content

Commit f13d2d7

Browse files
Merge pull request #1159 from onnx/tom/DynamicPartition
Added support for DynamicPartition
2 parents 669422c + 64de3ae commit f13d2d7

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3258,6 +3258,30 @@ def func(x):
32583258
#self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
32593259
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32603260

3261+
@check_opset_min_version(9, "Compress")
3262+
def test_dynamic_partition_both_vector(self):
3263+
data_val = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)
3264+
part_val = np.array([0, 0, 1, 1, 0, 2, 1, 0], dtype=np.int32)
3265+
def func(data, partitions):
3266+
p1, p2, p3 = tf.dynamic_partition(data, partitions, num_partitions=3)
3267+
p1_ = tf.identity(p1, name=_TFOUTPUT)
3268+
p2_ = tf.identity(p2, name=_TFOUTPUT1)
3269+
p3_ = tf.identity(p3, name=_TFOUTPUT2)
3270+
return p1_, p2_, p3_
3271+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})
3272+
3273+
@check_opset_min_version(9, "Compress")
3274+
def test_dynamic_partition_data_tensor(self):
3275+
data_val = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
3276+
part_val = np.array([0, 2, 1, 0, 1], dtype=np.int32)
3277+
def func(data, partitions):
3278+
p1, p2, p3 = tf.dynamic_partition(data, partitions, num_partitions=3)
3279+
p1_ = tf.identity(p1, name=_TFOUTPUT)
3280+
p2_ = tf.identity(p2, name=_TFOUTPUT1)
3281+
p3_ = tf.identity(p3, name=_TFOUTPUT2)
3282+
return p1_, p2_, p3_
3283+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})
3284+
32613285
@check_opset_min_version(10, "Conv2DBackpropInput")
32623286
def test_Conv2DBackpropInput_const(self):
32633287
input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,34 @@ def version_11(cls, ctx, node, **kwargs):
17901790
# FIXME: the indices in onnx are not the same as in tensorflow.
17911791

17921792

1793+
@tf_op("DynamicPartition")
1794+
class DynamicPartition:
1795+
@classmethod
1796+
def version_9(cls, ctx, node, **kwargs):
1797+
# For desired behavior, see diagram: https://www.tensorflow.org/api_docs/python/tf/raw_ops/DynamicPartition
1798+
data_inp = node.input[0]
1799+
partition_inp = node.input[1]
1800+
partition_shape = ctx.get_shape(partition_inp)
1801+
num_partitions = node.get_attr_value('num_partitions')
1802+
utils.make_sure(partition_shape is not None, "DynamicPartition requires known rank")
1803+
utils.make_sure(len(partition_shape) == 1, "DynamicPartition only implemented for partitions of rank 1")
1804+
# Put partitions into OneHot format
1805+
range_val = np.arange(num_partitions, dtype=np.int32).reshape([num_partitions, 1])
1806+
range_const = ctx.make_const(utils.make_name('range_const'), range_val)
1807+
equal_node = ctx.make_node("Equal", [partition_inp, range_const.output[0]])
1808+
# Cast bool to int since ORT doesn't implement Split on bool.
1809+
equal_int32 = ctx.make_node("Cast", [equal_node.output[0]], attr={"to": TensorProto.INT32})
1810+
split_node = ctx.make_node("Split", [equal_int32.output[0]], output_count=num_partitions, attr={'axis': 0})
1811+
for i in range(num_partitions):
1812+
cond_bools = ctx.make_node("Cast", [split_node.output[i]], attr={"to": TensorProto.BOOL})
1813+
squeeze_node = ctx.make_node("Squeeze", [cond_bools.output[0]], attr={'axes': [0]})
1814+
compress_node = ctx.make_node("Compress", [data_inp, squeeze_node.output[0]], attr={'axis': 0})
1815+
ctx.replace_all_inputs(node.output[i], compress_node.output[0])
1816+
ctx.copy_dtype(node.output[i], compress_node.output[0])
1817+
ctx.copy_shape(node.output[i], compress_node.output[0])
1818+
ctx.remove_node(node.name)
1819+
1820+
17931821
@tf_op("MatrixDiagPart")
17941822
class MatrixDiagPart:
17951823
@classmethod

0 commit comments

Comments
 (0)