Skip to content

Commit 5c83131

Browse files
committed
honor output_dtype in multinominal; cleanup dtype mapping
1 parent 2837878 commit 5c83131

File tree

3 files changed

+17
-21
lines changed

3 files changed

+17
-21
lines changed

tests/test_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_trig_ops(self):
175175
def test_multinomial(self):
176176
x_val = make_xval([3, 4])
177177
x = tf.placeholder(tf.float32, shape=x_val.shape, name=_TFINPUT)
178-
op = tf.multinomial(x, 2)
178+
op = tf.multinomial(x, 2, output_dtype=tf.int64)
179179
output = tf.identity(op, name=_TFOUTPUT)
180180
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
181181
self.assertAllClose(expected, actual, rtol=1e-06)
@@ -726,11 +726,11 @@ def test_cast(self):
726726

727727
@unittest.skip
728728
def test_onehot(self):
729-
# FIXME via onnx-ml ?
729+
# no such op in onnx
730730
x_val = np.array([0, 1, 2], dtype=np.int32)
731731
depth = 3
732732
x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
733-
x_ = tf.one_hot(x, depth)
733+
x_ = tf.one_hot(x, depth, on_value=1, axis=0, off_value=0)
734734
output = tf.identity(x_, name=_TFOUTPUT)
735735
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
736736
self.assertAllClose(expected, actual)

tf2onnx/tfonnx.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,14 @@ def tensorflow_to_onnx(graph):
6666
for a in node.node_def.attr:
6767
attr_cnt[a] += 1
6868
if a == "dtype":
69-
attr[a] = utils.get_tf_dtype(node)
69+
attr[a] = utils.map_tf_dtype(node.get_attr("dtype"))
7070
elif a == "T":
7171
dtype = node.get_attr("T")
7272
if dtype:
7373
if not isinstance(dtype, list):
74-
dtypes[node.name] = utils.TF_TO_ONNX_DTYPE.get(dtype)
75-
elif a == "output_type":
76-
out_type = node.get_attr("output_type")
77-
out_type = utils.TF_TO_ONNX_DTYPE[out_type]
78-
attr[a] = out_type
79-
elif a == "out_type":
80-
out_type = node.get_attr("out_type")
81-
out_type = utils.TF_TO_ONNX_DTYPE[out_type]
82-
attr[a] = out_type
74+
dtypes[node.name] = utils.map_tf_dtype(dtype)
75+
elif a in ["output_type", "output_dtype", "out_type"]:
76+
attr[a] = utils.map_tf_dtype(node.get_attr(a))
8377
elif a == "shape":
8478
attr[a] = utils.get_shape(node)
8579
elif a == "Tperm":
@@ -90,9 +84,7 @@ def tensorflow_to_onnx(graph):
9084
onnx_tensor = utils.tf_to_onnx_tensor(node.get_attr(a), name=node.name + ":0")
9185
attr[a] = onnx_tensor
9286
elif a == "DstT":
93-
dst = node.get_attr("DstT")
94-
dst = tf2onnx.utils.TF_TO_ONNX_DTYPE[dst]
95-
attr["to"] = dst
87+
attr["to"] = utils.map_tf_dtype(node.get_attr("DstT"))
9688
elif a == "SrcT":
9789
continue
9890
elif a in ignored_attr:
@@ -768,12 +760,19 @@ def upsample_op(ctx, node, name, args):
768760
ctx.remove_input(node, node.input[1])
769761
return node
770762

763+
771764
def multinomial_op(ctx, node, name, args):
772765
# output_dtype output = Multinomial(T logits, int32 num_samples, @int seed, @int seed2, @type output_dtype)
773766
sample_size = node.inputs[1].get_tensor_value()
774767
seed = node.get_attr("seed")
775768
if seed:
776769
node.set_attr("seed", float(seed.i))
770+
output_dtype = node.get_attr("output_dtype")
771+
if output_dtype:
772+
output_dtype = output_dtype.i
773+
else:
774+
output_dtype = onnx_pb.TensorProto.INT32
775+
node.set_attr("dtype", output_dtype)
777776
node.set_attr("sample_size", sample_size[0])
778777
ctx.remove_input(node, node.input[1])
779778
return node

tf2onnx/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
'min', 'seed', 'ends', 'paddings', 'to', 'gamma', 'width_scale', 'normalize_variance', 'group', 'ratio', 'values',
7676
'dtype', 'output_shape', 'spatial', 'split', 'input_forget', 'keepdims', 'transA', 'auto_pad', 'border', 'low',
7777
'linear_before_reset', 'height_scale', 'output_padding', 'shape', 'kernel_shape', 'epsilon', 'size', 'starts',
78-
'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales', 'k'
78+
'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales', 'k', 'sample_size'
7979
}
8080

8181

@@ -148,10 +148,7 @@ def get_shape(node):
148148
pass
149149
return dims
150150

151-
152-
def get_tf_dtype(node):
153-
"""Get dtype from tensorflow node."""
154-
dtype = node.get_attr("dtype")
151+
def map_tf_dtype(dtype):
155152
if dtype:
156153
dtype = TF_TO_ONNX_DTYPE[dtype]
157154
return dtype

0 commit comments

Comments
 (0)