Skip to content

Commit 9d1080f

Browse files
Implement conversions of StringUpper and StringLower (#1343)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 78172d1 commit 9d1080f

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def requires_custom_ops(message=""):
207207
""" Skip until custom ops framework is on PyPI. """
208208
reason = _append_message("test needs custom ops framework", message)
209209
try:
210-
import ortcustomops #pylint: disable=import-outside-toplevel,unused-import
210+
import onnxruntime_customops #pylint: disable=import-outside-toplevel,unused-import
211211
can_import = True
212212
except ModuleNotFoundError:
213213
can_import = False

tests/test_backend.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,6 +2834,36 @@ def func(x):
28342834
return tf.identity(res, name=_TFOUTPUT)
28352835
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
28362836

2837+
@check_tf_min_version("1.14", "tf.strings.lower")
2838+
@check_opset_min_version(10, "StringNormalizer")
2839+
def test_string_lower(self):
2840+
text_val1 = np.array([["a", "Test 1 2 3", "♠♣"], ["Hi there", "test test", "♥♦"]], dtype=np.str)
2841+
def func(text1):
2842+
x = tf.strings.lower(text1)
2843+
x_ = tf.identity(x, name=_TFOUTPUT)
2844+
return x_
2845+
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1})
2846+
2847+
@check_tf_min_version("1.14", "tf.strings.lower")
2848+
@check_opset_min_version(10, "StringNormalizer")
2849+
def test_string_lower_flat(self):
2850+
text_val1 = np.array(["a", "Test 1 2 3", "♠♣", "Hi there", "test test", "♥♦"], dtype=np.str)
2851+
def func(text1):
2852+
x = tf.strings.lower(text1)
2853+
x_ = tf.identity(x, name=_TFOUTPUT)
2854+
return x_
2855+
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1})
2856+
2857+
@check_tf_min_version("1.14", "tf.strings.lower")
2858+
@check_opset_min_version(10, "StringNormalizer")
2859+
def test_string_upper(self):
2860+
text_val1 = np.array([["a", "Test 1 2 3", "♠♣"], ["Hi there", "test test", "♥♦"]], dtype=np.str)
2861+
def func(text1):
2862+
x = tf.strings.upper(text1)
2863+
x_ = tf.identity(x, name=_TFOUTPUT)
2864+
return x_
2865+
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1})
2866+
28372867
@check_opset_min_version(6, "cast")
28382868
def test_shape_int32(self):
28392869
x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=np.float32)

tests/test_string_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
117117

118118
def run_onnxruntime(self, model_path, inputs, output_names):
119119
"""Run test against onnxruntime backend."""
120-
from ortcustomops import get_library_path
120+
from onnxruntime_customops import get_library_path
121121
import onnxruntime as rt
122122
opt = rt.SessionOptions()
123123
opt.register_custom_ops_library(get_library_path())

tf2onnx/custom_opsets/string_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,28 @@ def version_1(cls, ctx, node, **kwargs):
121121
ctx.copy_shape(output_name, not_node.output[0])
122122
ctx.copy_dtype(output_name, not_node.output[0])
123123

124+
@tf_op(["StringLower", "StringUpper"])
125+
class StringLower:
126+
@classmethod
127+
def version_10(cls, ctx, node, **kwargs):
128+
if node.type == "StringLower":
129+
case_action = "LOWER"
130+
else:
131+
case_action = "UPPER"
132+
node.type = "StringNormalizer"
133+
str_input = node.input[0]
134+
rank = ctx.get_rank(node.input[0])
135+
shape = ctx.get_shape(node.input[0])
136+
if rank != 1:
137+
ctx.insert_new_node_on_input(node, "Flatten", node.input[0], axis=0)
138+
node.set_attr("case_change_action", case_action)
139+
if rank != 1:
140+
if shape is None or -1 in shape:
141+
new_shape = ctx.make_node("Shape", [str_input]).output[0]
142+
else:
143+
new_shape = ctx.make_const(utils.make_name("shape"), np.array(shape, np.int64)).output[0]
144+
ctx.insert_new_node_on_output("Reshape", node.output[0], inputs=[node.output[0], new_shape])
145+
124146
@tf_op("SentencepieceOp", domain=constants.CONTRIB_OPS_DOMAIN)
125147
class SentencepieceOp:
126148
@classmethod

0 commit comments

Comments
 (0)