18
18
from onnx .onnx_pb import TensorProto
19
19
from tf2onnx import utils
20
20
from tf2onnx .handler import tf_op
21
- from tf2onnx .utils import make_sure
22
21
from tf2onnx .tf_loader import find_function
23
22
24
23
@@ -278,6 +277,7 @@ def version_7(cls, ctx, node, **kwargs):
278
277
# T output = Select(bool condition, T x, T y)
279
278
# Select_res = Add(Multiply(Cast(bool condition, float32), T x,),
280
279
# Multiply(Cast(Not(bool condition), float32), T y)).
280
+ # TODO: Fix case where condition is 1-dimensional
281
281
utils .make_sure (len (node .input ) > 1 , "Select with only condition is not supported." )
282
282
positive_cast = ctx .make_node ("Cast" , [node .input [0 ]], name = utils .make_name (node .name ),
283
283
attr = {"to" : TensorProto .FLOAT })
@@ -302,11 +302,13 @@ def version_8(cls, ctx, node, **kwargs):
302
302
303
303
true_data_type = ctx .get_dtype (node .input [1 ])
304
304
true_data_shape = ctx .get_shape (node .input [1 ])
305
- make_sure (true_data_type is not None , "select true data dtype cannot be None" )
306
- make_sure (true_data_shape is not None , "select true data shape cannot be None" )
307
-
308
305
condition_shape = ctx .get_shape (node .input [0 ])
309
- utils .make_sure (condition_shape is not None , "Shape of {} is None" .format (node .input [0 ]))
306
+
307
+ if true_data_type is None or true_data_shape is None or condition_shape is None :
308
+ # Fallback if shape is unknown
309
+ cls .version_7 (ctx , node , ** kwargs )
310
+ return
311
+
310
312
rank = len (condition_shape )
311
313
312
314
utils .make_sure (rank >= 0 , "rank should be >= 0" )
@@ -361,13 +363,16 @@ def version_9(cls, ctx, node, **kwargs):
361
363
# T1 output = Where(bool condition, T1 x, T1 y)
362
364
# NOTE: condition can be 1-dimension in tensorflow, while in onnx,
363
365
# it should be broadcastable with other two inputs
364
- node . type = "Where"
366
+
365
367
cond_shape = ctx .get_shape (node .input [0 ])
366
- make_sure (cond_shape is not None , "shape of {} is None" .format (node .input [0 ]))
367
368
input_shape = ctx .get_shape (node .input [1 ])
368
369
if input_shape is None :
369
370
input_shape = ctx .get_shape (node .input [2 ])
370
- make_sure (input_shape is not None , "input shape of {} is None" .format (node .name ))
371
+ if cond_shape is None or input_shape is None :
372
+ # Fallback if shape is unknown
373
+ cls .version_7 (ctx , node , ** kwargs )
374
+ return
375
+ node .type = "Where"
371
376
input_rank = len (input_shape )
372
377
# if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
373
378
if len (cond_shape ) == 1 and input_rank > 1 :
0 commit comments