Skip to content

Commit d536f04

Browse files
hwangdeyufatcat-z
andauthored
Add tf custom op conversion example (#1878)
* add tf custom op example Signed-off-by: hwangdeyu <[email protected]> * fix name and remove unused code comment Signed-off-by: hwangdeyu <[email protected]> Co-authored-by: fatcat-z <[email protected]>
1 parent 8f2e84b commit d536f04

File tree

5 files changed

+101
-2
lines changed

5 files changed

+101
-2
lines changed

examples/tf_custom_op/add_one.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
#include "tensorflow/core/framework/op.h"
6+
#include "tensorflow/core/framework/shape_inference.h"
7+
8+
using namespace tensorflow;
9+
10+
11+
// opregister
12+
REGISTER_OP("AddOne")
13+
.Input("add_one: int32")
14+
.Output("result: int32")
15+
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
16+
c->set_output(0, c->input(0));
17+
return Status::OK();
18+
});
19+
20+
21+
// keneldefinition
22+
#include "tensorflow/core/framework/op_kernel.h"
23+
24+
class AddOneOp : public OpKernel {
25+
public:
26+
explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
27+
28+
void Compute(OpKernelContext* context) override {
29+
// Tensor in input
30+
const Tensor& input_tensor = context->input(0);
31+
auto input = input_tensor.flat<int32>();
32+
33+
// Tensor in output
34+
Tensor* output_tensor = NULL;
35+
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
36+
auto output = output_tensor->flat<int32>();
37+
38+
const int N = input.size();
39+
for (int i = 0; i < N; i++) {
40+
output(i) += 1;
41+
}
42+
}
43+
};
44+
45+
46+
REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp);

examples/tf_custom_op/add_one.so

30.2 KB
Binary file not shown.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
import tf2onnx
7+
import onnx
8+
import os
9+
from tf2onnx import utils
10+
from tf2onnx.handler import tf_op
11+
from tf2onnx.tf_loader import tf_placeholder
12+
13+
14+
DIR_PATH = os.path.realpath(os.path.dirname(__file__))
15+
saved_model_path = os.path.join(DIR_PATH, "model.onnx")
16+
tf_library_path = os.path.join(DIR_PATH, "add_one.so")
17+
18+
19+
@tf_op("AddOne", onnx_op="Add")
20+
class AddOne:
21+
@classmethod
22+
def version_1(cls, ctx, node, **kwargs):
23+
node_shape = ctx.get_shape(node.input[0])
24+
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype = np.int32)).output[0]
25+
node.input.append(const_one)
26+
27+
28+
with tf.compat.v1.Session() as sess:
29+
x = tf_placeholder(tf.int32, [2, 3], name="input")
30+
AddOne = tf.load_op_library(tf_library_path)
31+
x_ = AddOne.add_one(x)
32+
_ = tf.identity(x_, name="output")
33+
34+
onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
35+
input_names=["input:0"],
36+
output_names=["output:0"])
37+
model_proto = onnx_graph.make_model("test")
38+
with open(saved_model_path, "wb") as f:
39+
f.write(model_proto.SerializeToString())
40+
41+
onnx_model = onnx.load(saved_model_path)
42+
onnx.checker.check_model(onnx_model)
43+
44+
45+
46+
## Run the model in ONNXRuntime to verify the result.
47+
import onnxruntime as ort
48+
input = np.arange(6).reshape(2,3).astype(np.int32)
49+
ort_session = ort.InferenceSession(saved_model_path)
50+
ort_inputs = {ort_session.get_inputs()[0].name: input}
51+
52+
ort_outs = ort_session.run(None, ort_inputs)
53+
print("input:", input, "\nort_outs:", ort_outs)

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
676676
return node
677677

678678
def append_node(self, node):
679-
"Add a node to the graph."
679+
"""Add a node to the graph."""
680680
output_shapes = node.output_shapes
681681
output_dtypes = node.output_dtypes
682682
node.graph = self

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_erro
518518
# or override existing ops with a custom op.
519519
if custom_op_handlers is not None:
520520
# below is a bit tricky since there are a few api's:
521-
# 1. the future way we want custom ops to be registered with the @tf_op decorator. THose handlers will be
521+
# 1. the future way we want custom ops to be registered with the @tf_op decorator. Those handlers will be
522522
# registered via the decorator on load of the module ... nothing is required here.
523523
# 2. the old custom op api: a dictionary of {name: (func, args[])
524524
# We deal with this by using a compat_handler that wraps to old handler with a new style handler.

0 commit comments

Comments
 (0)