Skip to content

Commit 059afaf

Browse files
Enable string tests in CI pipeline (#1450)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 4e1315a commit 059afaf

File tree

5 files changed

+31
-22
lines changed

5 files changed

+31
-22
lines changed

ci_build/azure_pipelines/templates/setup.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ steps:
2525
pip install --index-url https://test.pypi.org/simple/ ort-nightly
2626
fi
2727
28+
if [[ $CI_PLATFORM == "windows" ]] ;
29+
then
30+
pip install -i https://test.pypi.org/simple/ onnxruntime-customops==0.0.1
31+
if [[ $CI_TF_VERSION == 2.* ]] ;
32+
then
33+
pip install tensorflow-text
34+
fi
35+
fi
36+
2837
python setup.py install
2938
pip freeze --all
3039
displayName: 'Setup Environment'

ci_build/azure_pipelines/unit_test.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,16 @@ stages:
6767
- template: 'unit_test.yml'
6868
report_coverage: 'True'
6969

70+
- template: 'templates/job_generator.yml'
71+
parameters:
72+
python_versions: ['3.7']
73+
platforms: ['windows']
74+
tf_versions: ['2.3.0']
75+
onnx_opsets: ['']
76+
job:
77+
steps:
78+
- template: 'unit_test.yml'
79+
report_coverage: 'True'
80+
7081
- template: 'templates/combine_test_coverage.yml'
7182

tests/backend_test_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, c
109109
for expected_val, actual_val in zip(expected, actual):
110110
if check_value:
111111
if expected_val.dtype == np.object:
112-
decode = np.vectorize(lambda x: x.decode('UTF-8'))
112+
# TFLite pads strings with nul bytes
113+
decode = np.vectorize(lambda x: x.replace(b'\x00', b'').decode('UTF-8'))
113114
expected_val_str = decode(expected_val)
114115
self.assertAllEqual(expected_val_str, actual_val)
115116
else:

tests/test_string_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import tensorflow as tf
1414

1515
from backend_test_base import Tf2OnnxBackendTestBase
16-
from common import requires_custom_ops
16+
from common import requires_custom_ops, check_tf_min_version, check_opset_min_version
1717
from tf2onnx import utils
1818
from tf2onnx import constants
1919

@@ -47,6 +47,7 @@ def func(text):
4747
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val})
4848

4949
@requires_custom_ops("StringJoin")
50+
@check_opset_min_version(8, "Expand")
5051
def test_string_join(self):
5152
text_val1 = np.array([["a", "Test 1 2 3"], ["Hi there", "test test"]], dtype=np.str)
5253
text_val2 = np.array([["b", "Test 1 2 3"], ["Hi there", "suits ♠♣♥♦"]], dtype=np.str)
@@ -57,6 +58,7 @@ def func(text1, text2, text3):
5758
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1, _INPUT1: text_val2, _INPUT2: text_val3})
5859

5960
@requires_custom_ops("StringSplit")
61+
@check_tf_min_version("2.0", "result is sparse not ragged in tf1")
6062
def test_string_split(self):
6163
text_val = np.array([["a", "Test 1 2 3"], ["Hi there", "test test"]], dtype=np.str)
6264
def func(text):
@@ -114,6 +116,7 @@ def func(x1, x2):
114116
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
115117

116118
@requires_custom_ops("RegexSplitWithOffsets")
119+
@check_tf_min_version("2.0", "tensorflow_text")
117120
def test_regex_split_with_offsets(self):
118121
from tensorflow_text.python.ops.regex_split_ops import (
119122
gen_regex_split_ops as lib_gen_regex_split_ops)
@@ -145,6 +148,8 @@ def run_onnxruntime(self, model_path, inputs, output_names):
145148
return results
146149

147150
@requires_custom_ops("WordpieceTokenizer")
151+
@check_tf_min_version("2.0", "tensorflow_text")
152+
@unittest.skip("Not fixed yet")
148153
def test_wordpiece_tokenizer(self):
149154
from tensorflow_text.python.ops.wordpiece_tokenizer import (
150155
gen_wordpiece_tokenizer as lib_gen_wordpiece_tokenizer)

tf2onnx/custom_opsets/string_ops.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@tf_op(["StringSplit", "StringSplitV2"], domain=constants.CONTRIB_OPS_DOMAIN)
2020
class StringOps:
2121
@classmethod
22-
def any_version(cls, opset, ctx, node, **kwargs):
22+
def version_1(cls, ctx, node, **kwargs):
2323
if node.type == "StringSplit":
2424
skip_empty = node.get_attr_value('skip_empty', True)
2525
else:
@@ -33,15 +33,6 @@ def any_version(cls, opset, ctx, node, **kwargs):
3333
skip_empty_const = ctx.make_const(utils.make_name('skip_empty_const'), np.array([skip_empty], np.bool))
3434
ctx.replace_inputs(node, [node.input[0], unsqueeze_node.output[0], skip_empty_const.output[0]])
3535

36-
@classmethod
37-
def version_1(cls, ctx, node, **kwargs):
38-
cls.any_version(1, ctx, node, **kwargs)
39-
40-
@classmethod
41-
def version_13(cls, ctx, node, **kwargs):
42-
cls.any_version(13, ctx, node, **kwargs)
43-
44-
4536
@tf_op("StringToHashBucketFast", domain=constants.CONTRIB_OPS_DOMAIN)
4637
class StringToHashBucketFast:
4738
@classmethod
@@ -72,7 +63,7 @@ def version_1(cls, ctx, node, **kwargs):
7263
@tf_op("StringJoin", domain=constants.CONTRIB_OPS_DOMAIN)
7364
class StringJoin:
7465
@classmethod
75-
def any_version(cls, opset, ctx, node, **kwargs):
66+
def version_1(cls, ctx, node, **kwargs):
7667
node.domain = constants.CONTRIB_OPS_DOMAIN
7768
separator = node.get_attr_value("separator")
7869
if separator is None:
@@ -87,22 +78,14 @@ def any_version(cls, opset, ctx, node, **kwargs):
8778
unsqueezes = []
8879
for inp in node.input:
8980
if ctx.get_shape(inp) == [] and shape_node is not None:
81+
utils.make_sure(ctx.opset >= 8, "Opset 8 required for Expand node for StringJoin")
9082
expand_node = ctx.make_node("Expand", [inp, shape_node.output[0]])
9183
inp = expand_node.output[0]
9284
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': inp, 'axes': [0]})
9385
unsqueezes.append(unsqueeze_node)
9486
stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
9587
ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])
9688

97-
@classmethod
98-
def version_1(cls, ctx, node, **kwargs):
99-
cls.any_version(1, ctx, node, **kwargs)
100-
101-
@classmethod
102-
def version_13(cls, ctx, node, **kwargs):
103-
cls.any_version(13, ctx, node, **kwargs)
104-
105-
10689
@tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
10790
class StringEqual:
10891
@classmethod

0 commit comments

Comments
 (0)