Skip to content

Commit d95efa4

Browse files
committed
Check unsupported dtype for OneHot
1 parent c4c3ccd commit d95efa4

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

onnx_tf/handlers/backend/onehot.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import tensorflow as tf
33

4+
from onnx_tf.common import exception
45
from onnx_tf.handlers.backend_handler import BackendHandler
56
from onnx_tf.handlers.handler import onnx_op
67
from onnx_tf.handlers.handler import tf_func
@@ -10,19 +11,34 @@
1011
@tf_func(tf.one_hot)
1112
class OneHot(BackendHandler):
1213

14+
@classmethod
15+
def args_check(cls, node, **kwargs):
16+
tensor_dict = kwargs["tensor_dict"]
17+
indices = tensor_dict[node.inputs[0]]
18+
depth = tensor_dict[node.inputs[1]]
19+
if indices.dtype not in [tf.uint8, tf.int32, tf.int64]:
20+
exception.OP_UNSUPPORTED_EXCEPT(
21+
"OneHot indices must be in uint8 or int32 or int64 " +
22+
"but it is currently in " + str(indices.dtype) + " which",
23+
"Tensorflow")
24+
if depth.dtype not in [tf.int32]:
25+
exception.OP_UNSUPPORTED_EXCEPT(
26+
"OneHot depth must be in int32 but it is currently in " + str(
27+
depth.dtype) + " which", "Tensorflow")
28+
1329
@classmethod
1430
def version_9(cls, node, **kwargs):
1531
attrs = copy.deepcopy(node.attrs)
1632
tensor_dict = kwargs["tensor_dict"]
1733
indices = tensor_dict[node.inputs[0]]
18-
depth = tensor_dict[node.inputs[1]][0]
34+
depth = tensor_dict[node.inputs[1]]
1935
off_value = tensor_dict[node.inputs[2]][0]
2036
on_value = tensor_dict[node.inputs[2]][1]
2137
attrs["dtype"] = on_value.dtype
2238
return [
2339
cls.make_tensor_from_onnx_node(
2440
node,
25-
inputs=[indices, depth, on_value, off_value],
41+
inputs=[indices, depth[0], on_value, off_value],
2642
attrs=attrs,
2743
**kwargs)
2844
]

onnx_tf/handlers/frontend/onehot.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def version_9(cls, node, **kwargs):
2727
depth = node.inputs[1]
2828
axis = node.attr.get('axis', -1)
2929

30-
import pdb; pdb.set_trace()
3130
on_value = kwargs['consts'][node.inputs[2]].item(0)
3231
off_value = kwargs['consts'][node.inputs[3]].item(0)
3332
values = np.array([off_value, on_value])

test/backend/test_onnx_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
backend_test.exclude(r'test_mod_[a-z,_]*uint[0-9]+')
3838
backend_test.exclude(r'test_mod_[a-z,_]*int(8|(16))+')
3939

40+
# TF only support uint8, int32, int64 for indices and int32 for depth in
41+
# tf.one_hot
42+
backend_test.exclude(r'test_onehot_[a-z,_]*')
43+
4044
if legacy_opset_pre_ver(7):
4145
backend_test.exclude(r'[a-z,_]*Upsample[a-z,_]*')
4246

0 commit comments

Comments
 (0)