|
1 | 1 | import copy |
2 | 2 | import tensorflow as tf |
3 | 3 |
|
| 4 | +from onnx_tf.common import exception |
4 | 5 | from onnx_tf.handlers.backend_handler import BackendHandler |
5 | 6 | from onnx_tf.handlers.handler import onnx_op |
6 | 7 | from onnx_tf.handlers.handler import tf_func |
|
10 | 11 | @tf_func(tf.one_hot) |
11 | 12 | class OneHot(BackendHandler): |
12 | 13 |
|
| 14 | + @classmethod |
| 15 | + def args_check(cls, node, **kwargs): |
| 16 | + tensor_dict = kwargs["tensor_dict"] |
| 17 | + indices = tensor_dict[node.inputs[0]] |
| 18 | + depth = tensor_dict[node.inputs[1]] |
| 19 | + if indices.dtype not in [tf.uint8, tf.int32, tf.int64]: |
| 20 | + exception.OP_UNSUPPORTED_EXCEPT( |
| 21 | + "OneHot indices must be in uint8 or int32 or int64 " + |
| 22 | + "but it is currently in " + str(indices.dtype) + " which", |
| 23 | + "Tensorflow") |
| 24 | + if depth.dtype not in [tf.int32]: |
| 25 | + exception.OP_UNSUPPORTED_EXCEPT( |
| 26 | + "OneHot depth must be in int32 but it is currently in " + str( |
| 27 | + depth.dtype) + " which", "Tensorflow") |
| 28 | + |
13 | 29 | @classmethod |
14 | 30 | def version_9(cls, node, **kwargs): |
15 | 31 | attrs = copy.deepcopy(node.attrs) |
16 | 32 | tensor_dict = kwargs["tensor_dict"] |
17 | 33 | indices = tensor_dict[node.inputs[0]] |
18 | | - depth = tensor_dict[node.inputs[1]][0] |
| 34 | + depth = tensor_dict[node.inputs[1]] |
19 | 35 | off_value = tensor_dict[node.inputs[2]][0] |
20 | 36 | on_value = tensor_dict[node.inputs[2]][1] |
21 | 37 | attrs["dtype"] = on_value.dtype |
22 | 38 | return [ |
23 | 39 | cls.make_tensor_from_onnx_node( |
24 | 40 | node, |
25 | | - inputs=[indices, depth, on_value, off_value], |
| 41 | + inputs=[indices, depth[0], on_value, off_value], |
26 | 42 | attrs=attrs, |
27 | 43 | **kwargs) |
28 | 44 | ] |
0 commit comments