@@ -66,20 +66,14 @@ def tensorflow_to_onnx(graph):
66
66
for a in node .node_def .attr :
67
67
attr_cnt [a ] += 1
68
68
if a == "dtype" :
69
- attr [a ] = utils .get_tf_dtype (node )
69
+ attr [a ] = utils .map_tf_dtype (node . get_attr ( "dtype" ) )
70
70
elif a == "T" :
71
71
dtype = node .get_attr ("T" )
72
72
if dtype :
73
73
if not isinstance (dtype , list ):
74
- dtypes [node .name ] = utils .TF_TO_ONNX_DTYPE .get (dtype )
75
- elif a == "output_type" :
76
- out_type = node .get_attr ("output_type" )
77
- out_type = utils .TF_TO_ONNX_DTYPE [out_type ]
78
- attr [a ] = out_type
79
- elif a == "out_type" :
80
- out_type = node .get_attr ("out_type" )
81
- out_type = utils .TF_TO_ONNX_DTYPE [out_type ]
82
- attr [a ] = out_type
74
+ dtypes [node .name ] = utils .map_tf_dtype (dtype )
75
+ elif a in ["output_type" , "output_dtype" , "out_type" ]:
76
+ attr [a ] = utils .map_tf_dtype (node .get_attr (a ))
83
77
elif a == "shape" :
84
78
attr [a ] = utils .get_shape (node )
85
79
elif a == "Tperm" :
@@ -90,9 +84,7 @@ def tensorflow_to_onnx(graph):
90
84
onnx_tensor = utils .tf_to_onnx_tensor (node .get_attr (a ), name = node .name + ":0" )
91
85
attr [a ] = onnx_tensor
92
86
elif a == "DstT" :
93
- dst = node .get_attr ("DstT" )
94
- dst = tf2onnx .utils .TF_TO_ONNX_DTYPE [dst ]
95
- attr ["to" ] = dst
87
+ attr ["to" ] = utils .map_tf_dtype (node .get_attr ("DstT" ))
96
88
elif a == "SrcT" :
97
89
continue
98
90
elif a in ignored_attr :
@@ -768,12 +760,19 @@ def upsample_op(ctx, node, name, args):
768
760
ctx .remove_input (node , node .input [1 ])
769
761
return node
770
762
763
+
771
764
def multinomial_op (ctx , node , name , args ):
772
765
# output_dtype output = Multinomial(T logits, int32 num_samples, @int seed, @int seed2, @type output_dtype)
773
766
sample_size = node .inputs [1 ].get_tensor_value ()
774
767
seed = node .get_attr ("seed" )
775
768
if seed :
776
769
node .set_attr ("seed" , float (seed .i ))
770
+ output_dtype = node .get_attr ("output_dtype" )
771
+ if output_dtype :
772
+ output_dtype = output_dtype .i
773
+ else :
774
+ output_dtype = onnx_pb .TensorProto .INT32
775
+ node .set_attr ("dtype" , output_dtype )
777
776
node .set_attr ("sample_size" , sample_size [0 ])
778
777
ctx .remove_input (node , node .input [1 ])
779
778
return node
0 commit comments