Skip to content

Commit 04d2488

Browse files
authored
support for AsString, StringToNumber (#1648)
* support for AsString, StringToNumber Signed-off-by: Guenther Schmuelling <[email protected]> * pylint Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent 998f713 commit 04d2488

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

tests/test_backend.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5541,6 +5541,36 @@ def func(x):
55415541
x_val = np.array([1, 5, 2, 0, 3, 4], dtype=np.int64)
55425542
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
55435543

5544+
@check_tf_min_version("2.1")
5545+
@skip_tflite("TFlite errors on some attributes")
5546+
@check_opset_min_version(9, "string")
5547+
def test_asstring(self):
5548+
def func(x):
5549+
op_ = tf.strings.as_string(x)
5550+
return tf.identity(op_, name=_TFOUTPUT)
5551+
5552+
x_val = np.array([0, 1, 2, 3], dtype=np.int32)
5553+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5554+
5555+
x_val = np.array([0, 1, 2, 3], dtype=np.float32)
5556+
# can't check the values because in onnx they are padded with 0, in tf they are not
5557+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False)
5558+
5559+
@check_tf_min_version("2.1")
5560+
@skip_tflite("TFlite errors on some attributes")
5561+
@check_opset_min_version(9, "string")
5562+
def test_string_to_number(self):
5563+
def func(x):
5564+
op_ = tf.strings.to_number(x)
5565+
return tf.identity(op_, name=_TFOUTPUT)
5566+
5567+
# tf gets this wrong and returns fp32 instead of int
5568+
x_val = np.array("123", dtype=np.object)
5569+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
5570+
5571+
x_val = np.array("123.1", dtype=np.object)
5572+
# can't check the values because in onnx they are padded with 0, in tf they are not
5573+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, check_value=False)
55445574

55455575

55465576
if __name__ == '__main__':

tf2onnx/onnx_opset/tensor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3799,3 +3799,27 @@ def version_8(cls, ctx, node, **kwargs):
37993799
# broadcast by expanding
38003800
node.type = "Expand"
38013801
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
3802+
3803+
3804+
@tf_op("AsString")
3805+
class AsString:
3806+
@classmethod
3807+
def version_9(cls, ctx, node, **kwargs):
3808+
if (node.get_attr_value("precision") or node.get_attr_value("scientific") or node.get_attr_value("fill")):
3809+
logger.warning(
3810+
"ONNX does not support precision, scientific and fill attributes for AsString")
3811+
shapes = node.output_shapes
3812+
ctx.remove_node(node.name)
3813+
_ = ctx.make_node("Cast", node.input, name=node.name,
3814+
dtypes=[TensorProto.STRING], shapes=shapes, attr={"to": TensorProto.STRING})
3815+
3816+
3817+
@tf_op("StringToNumber")
3818+
class StringToNumber:
3819+
@classmethod
3820+
def version_9(cls, ctx, node, **kwargs):
3821+
shapes = node.output_shapes
3822+
dtypes = node.output_dtypes
3823+
ctx.remove_node(node.name)
3824+
_ = ctx.make_node("Cast", node.input, name=node.name,
3825+
dtypes=dtypes, shapes=shapes, attr={"to": dtypes[0]})

0 commit comments

Comments
 (0)