@@ -130,3 +130,27 @@ class AddN():
130
130
@classmethod
131
131
def version_6 (cls , ctx , node , ** kwargs ):
132
132
node .type = "Sum"
133
+
134
+
135
+ @tf_op ("SegmentSum" )
136
+ class SegmentSum ():
137
+ @classmethod
138
+ def version_9 (cls , ctx , node , ** kwargs ):
139
+ data_inp = node .input [0 ]
140
+ segment_inp = node .input [1 ]
141
+ data_shape = ctx .get_shape (data_inp )
142
+ utils .make_sure (data_shape is not None , "Segment ops require input rank to be known" )
143
+ data_np_dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (data_inp ))
144
+ seg_np_dtype = utils .map_onnx_to_numpy_type (ctx .get_dtype (segment_inp ))
145
+ max_segment = ctx .make_node ("ReduceMax" , [segment_inp ], attr = {'axes' : [0 ], 'keepdims' : 0 })
146
+ one_const = ctx .make_const (utils .make_name ("const_one" ), np .array (1 , dtype = seg_np_dtype ))
147
+ num_segments = ctx .make_node ("Add" , [max_segment .output [0 ], one_const .output [0 ]])
148
+ onehot_values = ctx .make_const (utils .make_name ("onehot_values" ), np .array ([0 , 1 ], dtype = data_np_dtype ))
149
+ one_hot_node = ctx .make_node ("OneHot" , [segment_inp , num_segments .output [0 ], onehot_values .output [0 ]], attr = {'axis' : 0 })
150
+ mul_node = ctx .make_node ("Mul" , [data_inp , one_hot_node .output [0 ]])
151
+
152
+ shapes = node .output_shapes
153
+ dtypes = node .output_dtypes
154
+ ctx .remove_node (node .name )
155
+ sum_node = ctx .make_node ("ReduceSum" , [mul_node .output [0 ]], attr = {'axes' : [1 ], 'keepdims' : 0 },
156
+ name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )
0 commit comments