@@ -132,7 +132,7 @@ def version_6(cls, ctx, node, **kwargs):
132
132
node .type = "Sum"
133
133
134
134
135
@tf_op(["SegmentSum", "SegmentProd", "SegmentMax", "SegmentMin"])
class SegmentSum():
    """Convert TF segment reduction ops (SegmentSum/Prod/Max/Min) to ONNX.

    Strategy: build a boolean one-hot mask of shape (num_segments, N) from the
    segment ids, use Where to select the data where the mask is set and the
    reduction's identity element elsewhere, then reduce along the original
    data axis. The identity element guarantees positions outside a segment
    never affect that segment's result.
    """

    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        # NOTE(review): the next four statements were not visible in the
        # reviewed diff hunk (original lines 139-142); reconstructed from how
        # the names are used below -- confirm against the full source file.
        data_inp = node.input[0]
        segment_inp = node.input[1]
        data_shape = ctx.get_shape(data_inp)
        utils.make_sure(data_shape is not None, "Segment ops require input rank to be known")
        data_rank = len(data_shape)
        data_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(data_inp))
        seg_np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(segment_inp))
        data_is_float = np.dtype(data_np_dtype).kind == 'f'
        data_is_int = np.dtype(data_np_dtype).kind == 'i'
        utils.make_sure(data_is_float or data_is_int, "dtype for Segment ops must be float or int")

        # Pick the ONNX reduction and the identity element of that reduction
        # (0 for sum, 1 for product, -inf/INT_MIN for max, +inf/INT_MAX for min).
        if node.type == "SegmentSum":
            onnx_op = "ReduceSum"
            identity_value = np.array(0, dtype=data_np_dtype)
        elif node.type == "SegmentProd":
            onnx_op = "ReduceProd"
            identity_value = np.array(1, dtype=data_np_dtype)
        elif node.type == "SegmentMax":
            onnx_op = "ReduceMax"
            if data_is_float:
                identity_value = np.array('-inf', dtype=data_np_dtype)
            else:
                # Wrap in an ndarray of the data dtype: every other branch
                # produces an ndarray and the value is fed to make_const below;
                # a bare Python int would also lose the intended integer width.
                identity_value = np.array(np.iinfo(data_np_dtype).min, dtype=data_np_dtype)
        else:  # SegmentMin -- @tf_op above restricts node.type to these four
            onnx_op = "ReduceMin"
            if data_is_float:
                identity_value = np.array('inf', dtype=data_np_dtype)
            else:
                identity_value = np.array(np.iinfo(data_np_dtype).max, dtype=data_np_dtype)

        # num_segments = max(segment_ids) + 1; TF segment ids are sorted,
        # non-negative and start at 0 (assumption from the TF op contract).
        max_segment = ctx.make_node("ReduceMax", [segment_inp], attr={'axes': [0], 'keepdims': 0})
        one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=seg_np_dtype))
        identity_const = ctx.make_const(utils.make_name("const_identity"), identity_value)
        num_segments = ctx.make_node("Add", [max_segment.output[0], one_const.output[0]])
        # ORT doesn't support bool for OneHot so we use float32 and cast to bool
        onehot_values = ctx.make_const(utils.make_name("onehot_values"), np.array([0, 1], dtype=np.float32))
        one_hot_node = ctx.make_node("OneHot", [segment_inp, num_segments.output[0], onehot_values.output[0]],
                                     attr={'axis': 0})
        one_hot_bool = ctx.make_node("Cast", [one_hot_node.output[0]], attr={"to": onnx_pb.TensorProto.BOOL})
        one_hot_unsqueeze = one_hot_bool

        # Unsqueeze trailing dims so the (num_segments, N) mask broadcasts
        # against data of rank > 1 in the Where below.
        if data_rank > 1:
            new_dims = list(range(2, 2 + data_rank - 1))
            one_hot_unsqueeze = ctx.make_node("Unsqueeze", [one_hot_bool.output[0]], attr={'axes': new_dims})

        # Select data inside each segment, the identity element outside it.
        mul_node = ctx.make_node("Where", [one_hot_unsqueeze.output[0], data_inp, identity_const.output[0]])

        shapes = node.output_shapes
        dtypes = node.output_dtypes
        ctx.remove_node(node.name)
        # Reduce over axis 1 (the original data axis), producing one value per
        # segment; reuse the original node name/outputs so consumers are intact.
        ctx.make_node(onnx_op, [mul_node.output[0]], attr={'axes': [1], 'keepdims': 0},
                      name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
0 commit comments