Skip to content

Commit 373908e

Browse files
Merge pull request #1154 from onnx/tom/FixSelectUnknownShape
Fixed regression for SelectV2 when shape is unknown
2 parents 1bc3079 + 15b3fc1 commit 373908e

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from onnx.onnx_pb import TensorProto
1919
from tf2onnx import utils
2020
from tf2onnx.handler import tf_op
21-
from tf2onnx.utils import make_sure
2221
from tf2onnx.tf_loader import find_function
2322

2423

@@ -278,6 +277,7 @@ def version_7(cls, ctx, node, **kwargs):
278277
# T output = Select(bool condition, T x, T y)
279278
# Select_res = Add(Multiply(Cast(bool condition, float32), T x,),
280279
# Multiply(Cast(Not(bool condition), float32), T y)).
280+
# TODO: Fix case where condition is 1-dimensional
281281
utils.make_sure(len(node.input) > 1, "Select with only condition is not supported.")
282282
positive_cast = ctx.make_node("Cast", [node.input[0]], name=utils.make_name(node.name),
283283
attr={"to": TensorProto.FLOAT})
@@ -302,11 +302,13 @@ def version_8(cls, ctx, node, **kwargs):
302302

303303
true_data_type = ctx.get_dtype(node.input[1])
304304
true_data_shape = ctx.get_shape(node.input[1])
305-
make_sure(true_data_type is not None, "select true data dtype cannot be None")
306-
make_sure(true_data_shape is not None, "select true data shape cannot be None")
307-
308305
condition_shape = ctx.get_shape(node.input[0])
309-
utils.make_sure(condition_shape is not None, "Shape of {} is None".format(node.input[0]))
306+
307+
if true_data_type is None or true_data_shape is None or condition_shape is None:
308+
# Fallback if shape is unknown
309+
cls.version_7(ctx, node, **kwargs)
310+
return
311+
310312
rank = len(condition_shape)
311313

312314
utils.make_sure(rank >= 0, "rank should be >= 0")
@@ -361,13 +363,16 @@ def version_9(cls, ctx, node, **kwargs):
361363
# T1 output = Where(bool condition, T1 x, T1 y)
362364
# NOTE: condition can be 1-dimension in tensorflow, while in onnx,
363365
# it should be broadcastable with other two inputs
364-
node.type = "Where"
366+
365367
cond_shape = ctx.get_shape(node.input[0])
366-
make_sure(cond_shape is not None, "shape of {} is None".format(node.input[0]))
367368
input_shape = ctx.get_shape(node.input[1])
368369
if input_shape is None:
369370
input_shape = ctx.get_shape(node.input[2])
370-
make_sure(input_shape is not None, "input shape of {} is None".format(node.name))
371+
if cond_shape is None or input_shape is None:
372+
# Fallback if shape is unknown
373+
cls.version_7(ctx, node, **kwargs)
374+
return
375+
node.type = "Where"
371376
input_rank = len(input_shape)
372377
# if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
373378
if len(cond_shape) == 1 and input_rank > 1:

0 commit comments

Comments
 (0)