Commit d480f11

Add conversion for SentencepieceTokenizeOp (#1309)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent b18067c commit d480f11

File tree

4 files changed, 56 insertions(+), 0 deletions(-)

tests/completed_perf_testing_models.yaml

Lines changed: 19 additions & 0 deletions
@@ -3359,3 +3359,22 @@ covid-twitter-bert:
   atol: 0.0005
   tag: "serve"
   signature_def: "serving_default"
+
+universal-sentence-encoder-multilingual:
+  disabled: false
+  skip_conversion: false
+  model: "C:/Users/tomwi/Documents/tfhubmodels/universal-sentence-encoder-multilingual"
+  converted_model: "C:/Users/tomwi/Documents/tfhubmodels/universal-sentence-encoder-multilingual/model.onnx"
+  model_type: saved_model
+  large_model: false
+  run_tf_frozen: false
+  use_custom_ops: true
+  input_get: get_sentences
+  inputs:
+    "inputs:0": [100]
+  outputs:
+    - Identity:0
+  rtol: 0.05
+  atol: 0.0005
+  tag: "serve"
+  signature_def: "serving_default"
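
This entry adds the multilingual Universal Sentence Encoder, whose preprocessing tokenizes with SentencePiece via tensorflow_text, to the pretrained-model tests, with use_custom_ops enabled so the contrib-domain ops in the converted graph are actually executed during the accuracy check. A rough sketch of running just this entry with the test runner (flag names and paths are assumptions, not part of this commit):

    # Hypothetical invocation; adjust flags and paths to your checkout.
    python tests/run_pretrained_models.py \
        --config tests/completed_perf_testing_models.yaml \
        --tests universal-sentence-encoder-multilingual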

tests/run_pretrained_models.py

Lines changed: 5 additions & 0 deletions
@@ -40,6 +40,11 @@
     # not needed for tf-2.0
     pass
 
+try:
+    import tensorflow_text  # pylint: disable=unused-import
+except ModuleNotFoundError:
+    pass
+
 from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils, constants
 from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx.tf_loader import tf_session, tf_reset_default_graph

tf2onnx/convert.py

Lines changed: 6 additions & 0 deletions
@@ -143,6 +143,12 @@ def main():
     if using_tf_opset:
         extra_opset.append(constants.TENSORFLOW_OPSET)
 
+    if any(opset.domain == constants.CONTRIB_OPS_DOMAIN for opset in extra_opset):
+        try:
+            import tensorflow_text  # pylint: disable=import-outside-toplevel
+        except ModuleNotFoundError:
+            logger.warning("tensorflow_text not installed. Model will fail to load if tensorflow_text ops are used.")
+
     # get the frozen tensorflow model from graphdef, checkpoint or saved_model.
     if args.graphdef:
         graph_def, inputs, outputs = tf_loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
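
The warning above only fires when the contrib ops domain was requested via --extra_opset. A hedged example of such an invocation (paths and opset numbers are placeholders, and ai.onnx.contrib is assumed to be the string behind constants.CONTRIB_OPS_DOMAIN):

    # Placeholder paths/opset values, not taken from this commit.
    python -m tf2onnx.convert \
        --saved-model ./universal-sentence-encoder-multilingual \
        --output model.onnx \
        --opset 12 \
        --extra_opset ai.onnx.contrib:1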

tf2onnx/custom_opsets/string_ops.py

Lines changed: 26 additions & 0 deletions
@@ -120,3 +120,29 @@ def version_1(cls, ctx, node, **kwargs):
         not_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
         ctx.copy_shape(output_name, not_node.output[0])
         ctx.copy_dtype(output_name, not_node.output[0])
+
+@tf_op("SentencepieceOp", domain=constants.CONTRIB_OPS_DOMAIN)
+class SentencepieceOp:
+    @classmethod
+    def version_1(cls, ctx, node, **kwargs):
+        # This op will be removed when its consumer is converted
+        pass
+
+@tf_op("SentencepieceTokenizeOp", domain=constants.CONTRIB_OPS_DOMAIN)
+class SentencepieceTokenizeOp:
+    @classmethod
+    def version_1(cls, ctx, node, **kwargs):
+        node.domain = constants.CONTRIB_OPS_DOMAIN
+        input_node = node.inputs[0]
+        utils.make_sure(input_node.type == "SentencepieceOp", "Input 0 to node %s is not SentencepieceOp", node.name)
+        ctx.remove_input(node, node.input[0], 0)
+
+        nbest_size_cast = ctx.make_node("Cast", [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
+        ctx.replace_input(node, node.input[1], nbest_size_cast, 1)
+        for i in range(1, len(node.input)):
+            unsqueeze = GraphBuilder(ctx).make_unsqueeze({'data': node.input[i], 'axes': [0]})
+            ctx.replace_input(node, node.input[i], unsqueeze, i)
+        node.set_attr("model", input_node.attr['model'].s)
+        node.type = "SentencepieceTokenizer"
+        if ctx.is_safe_to_remove_nodes([input_node]):
+            ctx.remove_node(input_node.name)
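
After this rewrite the node becomes a SentencepieceTokenizer in the contrib ops domain, so the exported model needs a runtime that supplies those custom ops. A minimal sketch of running the converted model with onnxruntime plus the onnxruntime-extensions custom-op library (the model path, sample sentences, and the TF-style tensor names taken from the test config above are illustrative, not part of this commit):

    # Sketch only: assumes onnxruntime and onnxruntime-extensions are installed.
    import numpy as np
    import onnxruntime as ort
    from onnxruntime_extensions import get_library_path

    so = ort.SessionOptions()
    # SentencepieceTokenizer is provided by the onnxruntime-extensions library
    # under the ai.onnx.contrib domain.
    so.register_custom_ops_library(get_library_path())

    sess = ort.InferenceSession("model.onnx", so)
    sentences = np.array(["Hello world", "How are you?"], dtype=object)
    input_name = sess.get_inputs()[0].name  # "inputs:0" in the test config above
    embeddings = sess.run(None, {input_name: sentences})[0]
    print(embeddings.shape)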
