Skip to content

Commit 2d29fbd

Browse files
authored
Change custom op example to new API and add the doc (#1883)
* change example to new api and add the doc Signed-off-by: Deyu Huang <[email protected]> * add more details and changes Signed-off-by: Deyu Huang <[email protected]> * fix typo Signed-off-by: Deyu Huang <[email protected]>
1 parent 18ebae1 commit 2d29fbd

File tree

7 files changed

+328
-99
lines changed

7 files changed

+328
-99
lines changed

examples/tf_custom_op/add_one.cc

Lines changed: 0 additions & 46 deletions
This file was deleted.

examples/tf_custom_op/add_one.so

-30.2 KB
Binary file not shown.

examples/tf_custom_op/addone_custom_op.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

examples/tf_custom_op/custom_op.md

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
<!--- SPDX-License-Identifier: Apache-2.0 -->
2+
3+
## Example of converting TensorFlow model with custom op to ONNX
4+
5+
This document describes the ways for exporting TensorFlow model with a custom operator, exporting the operator to ONNX format, and adding the operator to ONNX Runtime for model inference. Tensorflow provides abundant set of operators, and also provides the extending implmentation to register as the new operators. The new custom operators are usually not recognized by tf2onnx conversion and onnxruntime. So the TensorFlow custom ops should be exported using a combination of existing and/or new custom ONNX ops. Once the operator is converted to ONNX format, users can implement and register it with ONNX Runtime for model inference. This document explains the details of this process end-to-end, along with an example.
6+
7+
8+
### Required Steps
9+
10+
- [1](#step1) - Adding the Tensorflow custom operator implementation in C++ and registering it with TensorFlow
11+
- [2](#step2) - Exporting the custom Operator to ONNX, using:
12+
<br /> - a combination of existing ONNX ops
13+
<br /> or
14+
<br /> - a custom ONNX Operator
15+
- [3](#step3) - Adding the custom operator implementation and registering it in ONNX Runtime (required only if using a custom ONNX op in step 2)
16+
17+
18+
### Implement the Custom Operator
19+
Firstly, try to install the TensorFlow latest version (Nighly is better) build refer to [here](https://github.com/tensorflow/tensorflow#install). And then implement the custom operators saving in TensorFlow library format and the file usually ends with `.so`. We have a simple example of `AddOne`, which is adding one for a tensor.
20+
21+
22+
#### Define the op interface
23+
Specify the name of your op, its inputs (types and names) and outputs (types and names), as well as docstrings and any attrs the op might require.
24+
```
25+
#include "tensorflow/core/framework/op.h"
26+
#include "tensorflow/core/framework/shape_inference.h"
27+
#include "tensorflow/core/framework/register_types.h"
28+
29+
using namespace tensorflow;
30+
31+
32+
// opregister
33+
REGISTER_OP("DoubleAndAddOne")
34+
.Input("x: T")
35+
.Output("result: T")
36+
.Attr("T: {float, double, int32}")
37+
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
38+
::tensorflow::shape_inference::ShapeHandle shape_x = c->input(0);
39+
if (!c->RankKnown(shape_x)) {
40+
c->set_output(0, c->UnknownShape());
41+
return Status::OK();
42+
}
43+
c->set_output(0, shape_x);
44+
return Status::OK();
45+
})
46+
.Doc(R"doc(
47+
Calculate the value 2x + 1.
48+
x: A Tensor `Tensor`. Must be one of the types in `T`.
49+
50+
Returns: A `Tensor`. Has the same type with `x`.
51+
)doc");
52+
```
53+
54+
#### Implement the op kernel
55+
Create a class that extends `OpKernel` and overrides the `Compute()` method. The implementation is written to the function `Compute()`.
56+
57+
```
58+
#include "tensorflow/core/framework/op_kernel.h"
59+
60+
template <typename T>
61+
class DoubleAndAddOneOp : public OpKernel {
62+
public:
63+
explicit DoubleAndAddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
64+
65+
void Compute(OpKernelContext* context) override {
66+
// Grab the input tensor
67+
const Tensor& input_tensor = context->input(0);
68+
auto input = input_tensor.flat<T>();
69+
70+
// Tensor in output
71+
Tensor* output_tensor = NULL;
72+
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
73+
auto output = output_tensor->flat<T>();
74+
75+
const int N = input.size();
76+
for (int i = 0; i < N; i++) {
77+
output(i) = output(i) * T(2) + T(1);
78+
}
79+
}
80+
};
81+
```
82+
Add the Register kernel build,
83+
```
84+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
85+
.Device(DEVICE_CPU)
86+
.TypeConstraint<int>("T"),
87+
DoubleAndAddOneOp<int>);
88+
```
89+
Save below code in C++ `.cc` file,
90+
91+
#### Using C++ compiler to compile the op
92+
Assuming you have g++ installed, here is the sequence of commands you can use to compile your op into a dynamic library.
93+
```
94+
TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
95+
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
96+
g++ -std=c++14 -shared double_and_add_one.cc -o double_and_add_one.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
97+
```
98+
After below steps, we can get a TensorFlow custom op library `double_and_add_one.so`.
99+
100+
101+
### Convert the Operator to ONNX
102+
To be able to use this custom ONNX operator for inference, we need to add our custom operator to an inference engine. If the operator can be conbinded with exsiting [ONNX standard operators](https://github.com/onnx/onnx/blob/main/docs/Operators.md). The case will be easier:
103+
104+
1- using [--load_op_libraries](https://github.com/onnx/tensorflow-onnx#--load_op_libraries) in conversion command or `tf.load_op_library()` method in code to load the TensorFlow custom ops library.
105+
106+
2- implement the op handler, registered it with the `@tf_op` decorator. Those handlers will be registered via the decorator on load of the module. [Here](https://github.com/onnx/tensorflow-onnx/tree/main/tf2onnx/onnx_opset) are examples of TensorFlow op hander implementations.
107+
108+
```
109+
import numpy as np
110+
import tensorflow as tf
111+
import tf2onnx
112+
import onnx
113+
import os
114+
from tf2onnx import utils
115+
from tf2onnx.handler import tf_op
116+
117+
118+
DIR_PATH = os.path.realpath(os.path.dirname(__file__))
119+
saved_model_path = os.path.join(DIR_PATH, "model.onnx")
120+
tf_library_path = os.path.join(DIR_PATH, "double_and_add_one.so")
121+
122+
123+
@tf_op("DoubleAndAddOne")
124+
class DoubleAndAddOne:
125+
@classmethod
126+
def version_1(cls, ctx, node, **kwargs):
127+
node.type = "Mul"
128+
node_shape = ctx.get_shape(node.input[0])
129+
node_dtype = ctx.get_dtype(node.input[0])
130+
node_np_dtype = utils.map_onnx_to_numpy_type(node_dtype)
131+
132+
const_two = ctx.make_const(utils.make_name("const_two"), np.array([2]).astype(node_np_dtype)).output[0]
133+
node.input.append(const_two)
134+
135+
const_one = ctx.make_const(utils.make_name("const_one"), np.ones(node_shape, dtype=node_np_dtype)).output[0]
136+
op_name = utils.make_name(node.name)
137+
ctx.insert_new_node_on_output("Add", node.output[0], inputs=[node.output[0], const_one], name=op_name)
138+
139+
140+
@tf.function
141+
def func(x):
142+
custom_op = tf.load_op_library(tf_library_path)
143+
x_ = custom_op.double_and_add_one(x)
144+
output = tf.identity(x_, name="output")
145+
return output
146+
147+
spec = [tf.TensorSpec(shape=(2, 3), dtype=tf.int32, name="input")]
148+
149+
onnx_model, _ = tf2onnx.convert.from_function(func, input_signature=spec, opset=15)
150+
151+
with open(saved_model_path, "wb") as f:
152+
f.write(onnx_model.SerializeToString())
153+
154+
onnx_model = onnx.load(saved_model_path)
155+
onnx.checker.check_model(onnx_model)
156+
```
157+
158+
3- Run in ONNXRuntime, using `InferenceSession` to do inference and get the result.
159+
```
160+
import onnxruntime as ort
161+
input = np.arange(6).reshape(2,3).astype(np.int32)
162+
ort_session = ort.InferenceSession(saved_model_path)
163+
ort_inputs = {ort_session.get_inputs()[0].name: input}
164+
165+
ort_outs = ort_session.run(None, ort_inputs)
166+
print("input:", input, "\nAddOne ort_outs:", ort_outs)
167+
```
168+
169+
170+
If the operator can not using existing ONNX standard operators only, you need to go to [implement the operator in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md#implement-the-operator-in-onnx-runtime).
171+
172+
### References:
173+
1- [Create an custom TensorFlow op](https://www.tensorflow.org/guide/create_op)
174+
175+
2- [ONNX Runtime: Adding a New Op](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#register-a-custom-operator)
176+
177+
3- [PyTorch Custom Operators export to ONNX](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*/
4+
5+
#include "tensorflow/core/framework/op.h"
6+
#include "tensorflow/core/framework/shape_inference.h"
7+
#include "tensorflow/core/framework/op_kernel.h"
8+
#include "tensorflow/core/framework/register_types.h"
9+
10+
using namespace tensorflow;
11+
12+
13+
// opregister
14+
REGISTER_OP("DoubleAndAddOne")
15+
.Input("x: T")
16+
.Output("result: T")
17+
.Attr("T: {float, double, int32}")
18+
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
19+
::tensorflow::shape_inference::ShapeHandle shape_x = c->input(0);
20+
if (!c->RankKnown(shape_x)) {
21+
c->set_output(0, c->UnknownShape());
22+
return Status::OK();
23+
}
24+
c->set_output(0, shape_x);
25+
return Status::OK();
26+
})
27+
.Doc(R"doc(
28+
Calculate the value 2x + 1.
29+
x: A Tensor `Tensor`. Must be one of the types in `T`.
30+
31+
Returns: A `Tensor`. Has the same type with `x`.
32+
)doc");
33+
34+
35+
// keneldefinition
36+
template <typename T>
37+
class DoubleAndAddOneOp : public OpKernel {
38+
public:
39+
explicit DoubleAndAddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
40+
41+
void Compute(OpKernelContext* context) override {
42+
// Grab the input tensor
43+
const Tensor& input_tensor = context->input(0);
44+
auto input = input_tensor.flat<T>();
45+
46+
// Tensor in output
47+
Tensor* output_tensor = NULL;
48+
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
49+
auto output = output_tensor->flat<T>();
50+
51+
const int N = input.size();
52+
for (int i = 0; i < N; i++) {
53+
output(i) = output(i) * T(2) + T(1);
54+
}
55+
}
56+
};
57+
58+
59+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
60+
.Device(DEVICE_CPU)
61+
.TypeConstraint<float>("T"),
62+
DoubleAndAddOneOp<float>);
63+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
64+
.Device(DEVICE_CPU)
65+
.TypeConstraint<double>("T"),
66+
DoubleAndAddOneOp<double>);
67+
REGISTER_KERNEL_BUILDER(Name("DoubleAndAddOne")
68+
.Device(DEVICE_CPU)
69+
.TypeConstraint<int>("T"),
70+
DoubleAndAddOneOp<int>);
71+
72+
73+
#define REGISTER_KERNEL(type) \
74+
REGISTER_KERNEL_BUILDER( \
75+
Name("DoubleAndAddOne").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
76+
DoubleAndAddOneOp<type>)
77+
78+
REGISTER_KERNEL(float);
79+
REGISTER_KERNEL(double);
80+
REGISTER_KERNEL(int);
81+
82+
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
83+
#undef REGISTER_KERNEL
84+
70.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)