Skip to content

Commit fdabeea

Browse files
authored
Merge pull request #851 from PreethaVeera/Preetha/Select-opset7
Add support for Select op with Opset 7
2 parents 0a9db3f + 2331d60 commit fdabeea

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,27 @@ def version_11(cls, ctx, node, **kwargs):
274274

275275
@tf_op(["Select", "SelectV2"])
276276
class Select:
277+
@classmethod
278+
def version_7(cls, ctx, node, **kwargs):
279+
# T output = Select(bool condition, T x, T y)
280+
# Select_res = Add(Multiply(Cast(bool condition, float32), T x,),
281+
# Multiply(Cast(Not(bool condition), float32), T y)).
282+
utils.make_sure(len(node.input) > 1, "Select with only condition is not supported.")
283+
positive_cast = ctx.make_node("Cast", [node.input[0]], name=utils.make_name(node.name),
284+
attr={"to": TensorProto.FLOAT})
285+
negative = ctx.make_node("Not", [node.input[0]], name=utils.make_name(node.name))
286+
negative_cast = ctx.make_node("Cast", [negative.output[0]], name=utils.make_name(node.name),
287+
attr={"to": TensorProto.FLOAT})
288+
multiply_1 = ctx.make_node("Mul", [positive_cast.output[0], node.input[1]], name=utils.make_name(node.name))
289+
multiply_2 = ctx.make_node("Mul", [node.input[2], negative_cast.output[0]], name=utils.make_name(node.name))
290+
add_name = node.name
291+
add_out = node.output
292+
dtype = ctx.get_dtype(node.output[0])
293+
shape = ctx.get_shape(node.output[0])
294+
ctx.remove_node(node.name)
295+
ctx.make_node("Add", [multiply_1.output[0], multiply_2.output[0]], outputs=add_out, name=add_name,
296+
dtypes=[dtype], shapes=[shape])
297+
277298
@classmethod
278299
def version_8(cls, ctx, node, **kwargs):
279300
# T output = Select(bool condition, T x, T y)

0 commit comments

Comments
 (0)