@@ -145,13 +145,21 @@ def version_7(cls, ctx, node, **kwargs):
145
145
utils .make_sure (len (node .input ) > 1 , "Select with only condition is not supported." )
146
146
dtype = ctx .get_dtype (node .output [0 ])
147
147
utils .make_sure (dtype != TensorProto .STRING , "Select with dtype string requires opset 9" )
148
+ tmp_dtype = dtype
149
+ if tmp_dtype == TensorProto .BOOL :
150
+ tmp_dtype = TensorProto .INT32
148
151
149
152
cond_shape = ctx .get_shape (node .input [0 ])
150
153
input_shape = ctx .get_shape (node .input [1 ])
151
154
if input_shape is None :
152
155
input_shape = ctx .get_shape (node .input [2 ])
153
156
input_rank = len (input_shape ) if input_shape is not None else None
154
157
cond_rank = len (cond_shape ) if cond_shape is not None else None
158
+ true_inp = node .input [1 ]
159
+ false_inp = node .input [2 ]
160
+ if tmp_dtype != dtype :
161
+ true_inp = ctx .make_node ("Cast" , [true_inp ], op_name_scope = node .name , attr = {"to" : tmp_dtype }).output [0 ]
162
+ false_inp = ctx .make_node ("Cast" , [false_inp ], op_name_scope = node .name , attr = {"to" : tmp_dtype }).output [0 ]
155
163
# if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
156
164
if node .type == "Select" and cond_rank == 1 and input_rank != 1 :
157
165
utils .make_sure (input_rank is not None , "input_rank unknown and cond_rank == 1" )
@@ -161,18 +169,20 @@ def version_7(cls, ctx, node, **kwargs):
161
169
ctx .replace_input (node , node .input [0 ], reshape .output [0 ], 0 )
162
170
163
171
positive_cast = ctx .make_node ("Cast" , [node .input [0 ]], name = utils .make_name (node .name ),
164
- attr = {"to" : dtype })
172
+ attr = {"to" : tmp_dtype })
165
173
negative = ctx .make_node ("Not" , [node .input [0 ]], name = utils .make_name (node .name ))
166
174
negative_cast = ctx .make_node ("Cast" , [negative .output [0 ]], name = utils .make_name (node .name ),
167
- attr = {"to" : dtype })
168
- multiply_1 = ctx .make_node ("Mul" , [positive_cast .output [0 ], node . input [ 1 ] ], name = utils .make_name (node .name ))
169
- multiply_2 = ctx .make_node ("Mul" , [node . input [ 2 ] , negative_cast .output [0 ]], name = utils .make_name (node .name ))
175
+ attr = {"to" : tmp_dtype })
176
+ multiply_1 = ctx .make_node ("Mul" , [positive_cast .output [0 ], true_inp ], name = utils .make_name (node .name ))
177
+ multiply_2 = ctx .make_node ("Mul" , [false_inp , negative_cast .output [0 ]], name = utils .make_name (node .name ))
170
178
add_name = node .name
171
179
add_out = node .output
172
180
shape = ctx .get_shape (node .output [0 ])
173
181
ctx .remove_node (node .name )
174
182
ctx .make_node ("Add" , [multiply_1 .output [0 ], multiply_2 .output [0 ]], outputs = add_out , name = add_name ,
175
- dtypes = [dtype ], shapes = [shape ])
183
+ dtypes = [tmp_dtype ], shapes = [shape ])
184
+ if tmp_dtype != dtype :
185
+ ctx .insert_new_node_on_output ("Cast" , node .output [0 ], to = dtype )
176
186
177
187
@classmethod
178
188
def version_9 (cls , ctx , node , ** kwargs ):
0 commit comments