Skip to content

Commit 7439397

Browse files
Implemented conversion of DynamicStitch
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 64de3ae commit 7439397

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3282,6 +3282,16 @@ def func(data, partitions):
32823282
return p1_, p2_, p3_
32833283
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})
32843284

3285+
@check_opset_min_version(11, "ScatterElements")
3286+
def test_dynamic_stitch_both_vector(self):
3287+
data_val = np.array([[5, 1, 3], [7, 2, 4]], dtype=np.float32)
3288+
indices_val = np.array([[0, 1, 4], [2, 3, 5]], dtype=np.int32)
3289+
def func(indices, data):
3290+
x = tf.dynamic_stitch(tf.unstack(indices), tf.unstack(data))
3291+
x_ = tf.identity(x, name=_TFOUTPUT)
3292+
return x_
3293+
self._run_test_case(func, [_OUTPUT], {_INPUT: indices_val, _INPUT1: data_val})
3294+
32853295
@check_opset_min_version(10, "Conv2DBackpropInput")
32863296
def test_Conv2DBackpropInput_const(self):
32873297
input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,45 @@ def version_9(cls, ctx, node, **kwargs):
18181818
ctx.remove_node(node.name)
18191819

18201820

1821+
@tf_op(["DynamicStitch", "ParallelDynamicStitch"])
1822+
class DynamicStitch:
1823+
@classmethod
1824+
def version_10(cls, ctx, node, **kwargs):
1825+
num_partitions = len(node.input) // 2
1826+
index_inputs = node.input[:num_partitions]
1827+
data_inputs = node.input[num_partitions:]
1828+
index_shapes = [ctx.get_shape(inp) for inp in index_inputs]
1829+
data_shapes = [ctx.get_shape(inp) for inp in data_inputs]
1830+
utils.make_sure(all(s is not None and len(s) == 1 for s in index_shapes),
1831+
"DynamicPartition only implemented for index tensors of rank 1")
1832+
utils.make_sure(all(s is not None and len(s) == 1 for s in data_shapes),
1833+
"DynamicPartition only implemented for data tensors of rank 1")
1834+
dtype = ctx.get_dtype(node.output[0])
1835+
concat_indices = ctx.make_node("Concat", index_inputs, attr={'axis': 0})
1836+
concat_indices_int64 = ctx.make_node("Cast", [concat_indices.output[0]], attr={"to": TensorProto.INT64})
1837+
1838+
concat_data = ctx.make_node("Concat", data_inputs, attr={'axis': 0})
1839+
1840+
data_shape = ctx.make_node("Shape", [concat_data.output[0]])
1841+
expanded_indices = ctx.make_node("Expand", [concat_indices_int64.output[0], data_shape.output[0]])
1842+
1843+
max_index = ctx.make_node("ReduceMax", [concat_indices_int64.output[0]], attr={'axes': [0], 'keepdims': 1})
1844+
const_one = ctx.make_const(utils.make_name('const_one'), np.array([1], np.int64))
1845+
target_length = ctx.make_node("Add", [max_index.output[0], const_one.output[0]])
1846+
1847+
zero_tensor = helper.make_tensor("value", dtype, dims=[1], vals=[0])
1848+
zeros_of_shape = ctx.make_node("ConstantOfShape", [target_length.output[0]], attr={"value": zero_tensor})
1849+
1850+
name = node.name
1851+
outputs = node.output
1852+
ctx.remove_node(node.name)
1853+
ctx.make_node("ScatterElements",
1854+
[zeros_of_shape.output[0], expanded_indices.output[0], concat_data.output[0]],
1855+
name=name,
1856+
outputs=outputs,
1857+
attr={'axis': 0})
1858+
1859+
18211860
@tf_op("MatrixDiagPart")
18221861
class MatrixDiagPart:
18231862
@classmethod

0 commit comments

Comments
 (0)