@@ -46,7 +46,9 @@ def tflist_to_onnx(node_list, shape_override):
46
46
ignored_attr = ["unknown_rank" , "_class" , "Tshape" , "use_cudnn_on_gpu" , "Index" , "Tpaddings" ,
47
47
"TI" , "Tparams" , "Tindices" , "Tlen" , "Tdim" , "dynamic_size" , "Tmultiples" ,
48
48
"Tblock_shape" , "Tcrops" , "index_type" , "Taxis" , "U" , "maxval" ,
49
- "Tout" , "Tlabels" , "Tindex" , "element_shape" , "Targmax" , "T_threshold" ]
49
+ "Tout" , "Tlabels" , "Tindex" , "element_shape" , "Targmax" , "T_threshold" ,
50
+ "output_types" , "output_shapes" , "key_dtype" , "value_dtype" , "Tin" , "Tout" ]
51
+
50
52
# some stats
51
53
op_cnt = collections .Counter ()
52
54
attr_cnt = collections .Counter ()
@@ -80,8 +82,7 @@ def tflist_to_onnx(node_list, shape_override):
80
82
if dtype :
81
83
if not isinstance (dtype , list ):
82
84
dtypes [node .name ] = utils .map_tf_dtype (dtype )
83
- elif a in ["output_type" , "output_dtype" , "out_type" , "Tidx" ,
84
- "out_idx" , "key_dtype" , "value_dtype" , "Tin" , "Tout" ]:
85
+ elif a in ["output_type" , "output_dtype" , "out_type" , "Tidx" , "out_idx" ]:
85
86
# Tidx is used by Range
86
87
# out_idx is used by ListDiff
87
88
attr [a ] = utils .map_tf_dtype (utils .get_tf_node_attr (node , a ))
@@ -100,10 +101,6 @@ def tflist_to_onnx(node_list, shape_override):
100
101
continue
101
102
elif a in ignored_attr :
102
103
continue
103
- elif a == "output_types" :
104
- attr [a ] = [utils .map_tf_dtype (v ) for v in utils .get_tf_node_attr (node , a )]
105
- elif a == "output_shapes" :
106
- attr [a ] = utils .get_tf_output_shapes_attr (node )
107
104
else :
108
105
attr [a ] = utils .get_tf_node_attr (node , a )
109
106
0 commit comments