@@ -1818,6 +1818,45 @@ def version_9(cls, ctx, node, **kwargs):
1818
1818
ctx .remove_node (node .name )
1819
1819
1820
1820
1821
+ @tf_op (["DynamicStitch" , "ParallelDynamicStitch" ])
1822
+ class DynamicStitch :
1823
+ @classmethod
1824
+ def version_10 (cls , ctx , node , ** kwargs ):
1825
+ num_partitions = len (node .input ) // 2
1826
+ index_inputs = node .input [:num_partitions ]
1827
+ data_inputs = node .input [num_partitions :]
1828
+ index_shapes = [ctx .get_shape (inp ) for inp in index_inputs ]
1829
+ data_shapes = [ctx .get_shape (inp ) for inp in data_inputs ]
1830
+ utils .make_sure (all (s is not None and len (s ) == 1 for s in index_shapes ),
1831
+ "DynamicPartition only implemented for index tensors of rank 1" )
1832
+ utils .make_sure (all (s is not None and len (s ) == 1 for s in data_shapes ),
1833
+ "DynamicPartition only implemented for data tensors of rank 1" )
1834
+ dtype = ctx .get_dtype (node .output [0 ])
1835
+ concat_indices = ctx .make_node ("Concat" , index_inputs , attr = {'axis' : 0 })
1836
+ concat_indices_int64 = ctx .make_node ("Cast" , [concat_indices .output [0 ]], attr = {"to" : TensorProto .INT64 })
1837
+
1838
+ concat_data = ctx .make_node ("Concat" , data_inputs , attr = {'axis' : 0 })
1839
+
1840
+ data_shape = ctx .make_node ("Shape" , [concat_data .output [0 ]])
1841
+ expanded_indices = ctx .make_node ("Expand" , [concat_indices_int64 .output [0 ], data_shape .output [0 ]])
1842
+
1843
+ max_index = ctx .make_node ("ReduceMax" , [concat_indices_int64 .output [0 ]], attr = {'axes' : [0 ], 'keepdims' : 1 })
1844
+ const_one = ctx .make_const (utils .make_name ('const_one' ), np .array ([1 ], np .int64 ))
1845
+ target_length = ctx .make_node ("Add" , [max_index .output [0 ], const_one .output [0 ]])
1846
+
1847
+ zero_tensor = helper .make_tensor ("value" , dtype , dims = [1 ], vals = [0 ])
1848
+ zeros_of_shape = ctx .make_node ("ConstantOfShape" , [target_length .output [0 ]], attr = {"value" : zero_tensor })
1849
+
1850
+ name = node .name
1851
+ outputs = node .output
1852
+ ctx .remove_node (node .name )
1853
+ ctx .make_node ("ScatterElements" ,
1854
+ [zeros_of_shape .output [0 ], expanded_indices .output [0 ], concat_data .output [0 ]],
1855
+ name = name ,
1856
+ outputs = outputs ,
1857
+ attr = {'axis' : 0 })
1858
+
1859
+
1821
1860
@tf_op ("MatrixDiagPart" )
1822
1861
class MatrixDiagPart :
1823
1862
@classmethod
0 commit comments