@@ -130,3 +130,61 @@ class AddN():
130
130
@classmethod
131
131
def version_6 (cls , ctx , node , ** kwargs ):
132
132
node .type = "Sum"
133
+
134
+
135
@tf_op(["SegmentSum", "SegmentProd", "SegmentMax", "SegmentMin"])
class SegmentSum():
    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        """Convert TF segment reductions to ONNX via a one-hot mask.

        Strategy: build a boolean (num_segments, N) mask from the segment
        ids with OneHot, use Where to select data rows belonging to each
        segment (and the reduction's identity value everywhere else), then
        reduce over the row axis. Requires the data input's rank to be
        known at conversion time.
        """
        data_inp = node.input[0]
        segment_inp = node.input[1]
        data_shape = ctx.get_shape(data_inp)
        utils.make_sure(data_shape is not None, "Segment ops require input rank to be known")
        data_rank = len(data_shape)
        data_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(data_inp))
        seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
        data_is_float = np.dtype(data_np_dtype).kind == 'f'
        data_is_int = np.dtype(data_np_dtype).kind == 'i'
        utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")

        # The identity value is the reduction's neutral element, so that
        # masked-out positions contribute nothing to the result.
        if node.type == "SegmentSum":
            onnx_op = "ReduceSum"
            identity_value = np.array(0, dtype=data_np_dtype)
        elif node.type == "SegmentProd":
            onnx_op = "ReduceProd"
            identity_value = np.array(1, dtype=data_np_dtype)
        elif node.type == "SegmentMax":
            onnx_op = "ReduceMax"
            if data_is_float:
                identity_value = np.array('-inf', dtype=data_np_dtype)
            else:
                # FIX: np.iinfo(...).min is a plain Python int; make_const
                # expects an ndarray of the data dtype (as the other
                # branches already provide).
                identity_value = np.array(np.iinfo(data_np_dtype).min, dtype=data_np_dtype)
        elif node.type == "SegmentMin":
            onnx_op = "ReduceMin"
            if data_is_float:
                identity_value = np.array('inf', dtype=data_np_dtype)
            else:
                # FIX: wrap the plain int in an ndarray, same as above.
                identity_value = np.array(np.iinfo(data_np_dtype).max, dtype=data_np_dtype)
        else:
            # Registration covers exactly the four types above; fail loudly
            # instead of crashing later on an unbound identity_value.
            utils.make_sure(False, "unexpected node type for Segment op")

        # num_segments = max(segment_ids) + 1, computed in-graph so it
        # works for dynamic segment counts.
        max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
        one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
        identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
        num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]])
        # ORT doesn't support bool for OneHot so we use float32 and cast to bool
        onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
        # OneHot with axis=0 puts the depth dim first: shape (num_segments, N).
        one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]],
                                     attr={'axis': 0})
        one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]], attr={"to": onnx_pb.TensorProto.BOOL})
        one_hot_unsqueeze = one_hot_bool

        if data_rank > 1:
            # Append trailing singleton dims so the mask broadcasts against
            # data of shape (N, d1, ..., dk) inside Where.
            new_dims = list(range(2, 2 + data_rank - 1))
            one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_bool.output[0]], attr={'axes': new_dims})

        mul_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])

        # Replace the TF node with the reduction over the row axis (axis 1),
        # keeping its name, outputs, shapes and dtypes so downstream
        # consumers are unaffected.
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        ctx.make_node(onnx_op, [mul_node.output[0]], attr={'axes': [1], 'keepdims': 0},
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
0 commit comments