@@ -29,15 +29,21 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
29
29
graph .copy_shape (inp , inp_cast .output [0 ])
30
30
graph .set_dtype (inp_cast .output [0 ], target_dtype )
31
31
32
-
33
- def _add_cast_to_same_type_to_inputs (graph , node ):
32
+ def _add_cast_to_same_type_to_inputs (graph , node , supported_dtypes , target_dtype ):
34
33
common_dtype = graph .get_dtype (node .input [0 ])
34
+ if common_dtype not in supported_dtypes :
35
+ common_dtype = target_dtype
35
36
36
- for inp in node .input [ 1 :] :
37
+ for inp in node .input :
37
38
if graph .get_dtype (inp ) != common_dtype :
38
39
inp_cast = graph .insert_new_node_on_input (node , "Cast" , inp , to = common_dtype )
39
40
graph .copy_shape (inp , inp_cast .output [0 ])
40
41
graph .set_dtype (inp_cast .output [0 ], common_dtype )
42
+ if graph .is_const (inp ) and graph .get_tensor_value (inp ) == '' :
43
+ # Convert '' string constant to -1 int
44
+ # https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286
45
+ const_node = graph .get_node_by_output (inp )
46
+ const_node .set_tensor_value (utils .np .array (- 1 ))
41
47
42
48
43
49
@tf_op ("LogicalNot" , onnx_op = "Not" )
@@ -92,8 +98,24 @@ def version_7(cls, ctx, node, **kwargs):
92
98
93
99
@classmethod
94
100
def version_11 (cls , ctx , node , ** kwargs ):
95
- # starting with opset-11, equal supports all types (but both operands must be of the same type)
96
- _add_cast_to_same_type_to_inputs (ctx , node )
101
+ # starting with opset-11, equal supports all numerical types (but both operands must be of the same type)
102
+ # string type is not supported
103
+ supported_dtypes = [
104
+ TensorProto .BOOL ,
105
+ TensorProto .DOUBLE ,
106
+ TensorProto .FLOAT ,
107
+ TensorProto .FLOAT16 ,
108
+ TensorProto .INT8 ,
109
+ TensorProto .INT16 ,
110
+ TensorProto .INT32 ,
111
+ TensorProto .INT64 ,
112
+ TensorProto .UINT8 ,
113
+ TensorProto .UINT16 ,
114
+ TensorProto .UINT32 ,
115
+ TensorProto .UINT64
116
+ ]
117
+ target_dtype = TensorProto .INT32
118
+ _add_cast_to_same_type_to_inputs (ctx , node , supported_dtypes , target_dtype )
97
119
need_not = node .type == "NotEqual"
98
120
if need_not :
99
121
node .type = "Equal"
0 commit comments