@@ -1463,17 +1463,16 @@ def fill_op(ctx, node, name, args):
1463
1463
# both shape and value in tensorflow are passed as tensor.
1464
1464
# In onnx the value is an attribute so we need to fetch the value as const which
1465
1465
# sooner or later will be a problem for tensorflow-onnx.
1466
- shape = ctx .get_shape (node .output [0 ])
1467
- utils .make_sure (all (i >= 0 for i in shape ), "shape attr should not be less than zero" )
1468
- value = node .inputs [1 ].get_tensor_value ()
1469
- value_proto = numpy_helper .from_array (node .inputs [1 ].get_tensor_value (as_list = False ))
1470
- dtype = value_proto .data_type
1471
- # onnx spec says value MUST be float.
1472
- node .set_attr ("value" , float (value ))
1473
- node .set_attr ("shape" , shape )
1474
- node .set_attr ("dtype" , dtype )
1475
- del node .input [:]
1476
- return node
1466
+ # ConstantOfShape in onnxruntime only support int64, so insert cast op
1467
+ input_dtype_is_int64 = utils .ONNX_TO_NUMPY_DTYPE [ctx .get_dtype (node .input [0 ])] == np .int64
1468
+ if not input_dtype_is_int64 :
1469
+ cast_node = ctx .insert_new_node_on_input (node , "Cast" , node .input [0 ], to = onnx_pb .TensorProto .INT64 )
1470
+ dtype = ctx .get_dtype (node .output [0 ])
1471
+ value = np .array ([node .inputs [1 ].get_tensor_value ()]).astype (utils .ONNX_TO_NUMPY_DTYPE [dtype ])
1472
+ value_proto = numpy_helper .from_array (value )
1473
+ node .set_attr ("value" , value_proto )
1474
+ del node .input [1 ]
1475
+ return [node ] if input_dtype_is_int64 else [node , cast_node ]
1477
1476
1478
1477
1479
1478
def reverse_op8 (ctx , node , name , args ):
@@ -1723,6 +1722,18 @@ def logical_compare_op(ctx, node, name, args):
1723
1722
return nodes
1724
1723
1725
1724
1725
+ def where_op (ctx , node , name , args ):
1726
+ # T_y output = Where(T_x condition), return indices of elements whose value are True
1727
+ node .type = "NonZero"
1728
+ # in onnx, indices are returned in this way [[ind_a_0, ind_b_0, ...], [ind_a_1, ind_b_1,...]];
1729
+ # while in tf, the result will be [[ind_a_0, ind_a_1, ...], [ind_b_0, ind_b_1, ...], ...]
1730
+ # this is the reason a transpose node inserted here.
1731
+ transpose_node = ctx .insert_new_node_on_output ("Transpose" , node .output [0 ], name = utils .make_name ("where_op_added" ))
1732
+ ctx .copy_shape (node .output [0 ], transpose_node .output [0 ])
1733
+ ctx .copy_dtype (node .output [0 ], transpose_node .output [0 ])
1734
+ return [node , transpose_node ]
1735
+
1736
+
1726
1737
# map tensorflow ops to onnx ops. The format below is
1727
1738
# "TFOP": func_to_map, ["OnnxOp", ...]
1728
1739
#
@@ -1894,6 +1905,7 @@ def logical_compare_op(ctx, node, name, args):
1894
1905
"Less" : (logical_compare_op , []),
1895
1906
"ResizeBilinear" : (upsample_op9 , ["Upsample" , "linear" ]),
1896
1907
"ResizeNearestNeighbor" : (upsample_op9 , ["Upsample" , "nearest" ]),
1908
+ "Where" : (where_op , []),
1897
1909
}
1898
1910
1899
1911
_OPSETS = [
0 commit comments