@@ -274,6 +274,27 @@ def version_11(cls, ctx, node, **kwargs):
274
274
275
275
@tf_op (["Select" , "SelectV2" ])
276
276
class Select :
277
+ @classmethod
278
+ def version_7 (cls , ctx , node , ** kwargs ):
279
+ # T output = Select(bool condition, T x, T y)
280
+ # Select_res = Add(Multiply(Cast(bool condition, float32), T x,),
281
+ # Multiply(Cast(Not(bool condition), float32), T y)).
282
+ utils .make_sure (len (node .input ) > 1 , "Select with only condition is not supported." )
283
+ positive_cast = ctx .make_node ("Cast" , [node .input [0 ]], name = utils .make_name (node .name ),
284
+ attr = {"to" : TensorProto .FLOAT })
285
+ negative = ctx .make_node ("Not" , [node .input [0 ]], name = utils .make_name (node .name ))
286
+ negative_cast = ctx .make_node ("Cast" , [negative .output [0 ]], name = utils .make_name (node .name ),
287
+ attr = {"to" : TensorProto .FLOAT })
288
+ multiply_1 = ctx .make_node ("Mul" , [positive_cast .output [0 ], node .input [1 ]], name = utils .make_name (node .name ))
289
+ multiply_2 = ctx .make_node ("Mul" , [node .input [2 ], negative_cast .output [0 ]], name = utils .make_name (node .name ))
290
+ add_name = node .name
291
+ add_out = node .output
292
+ dtype = ctx .get_dtype (node .output [0 ])
293
+ shape = ctx .get_shape (node .output [0 ])
294
+ ctx .remove_node (node .name )
295
+ ctx .make_node ("Add" , [multiply_1 .output [0 ], multiply_2 .output [0 ]], outputs = add_out , name = add_name ,
296
+ dtypes = [dtype ], shapes = [shape ])
297
+
277
298
@classmethod
278
299
def version_8 (cls , ctx , node , ** kwargs ):
279
300
# T output = Select(bool condition, T x, T y)
0 commit comments