Skip to content

Commit 9bb355b

Browse files
C++ ORT (#3)
1 parent b4d0691 commit 9bb355b

File tree

16 files changed

+1385
-0
lines changed

16 files changed

+1385
-0
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,11 @@
33
*.ckpt
44
*.onnx
55
*.tar.gz
6+
*.o
7+
*.a
8+
*.so
9+
*.out
10+
*.bak
11+
*.pb.*
12+
core.*
613
__pycache__

C++/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
client
2+
infer_server

C++/Makefile

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Build for the brpc-based news-classification demo:
#   infer_server — brpc server wrapping the ONNX BERT model
#   client       — brpc client sending classification requests
CXX = g++
PROTOC = protoc
BRPC_PATH= ~/github/incubator-brpc/
HDRS+=$(BRPC_PATH)/output/include
LIBS+=$(BRPC_PATH)/output/lib

# Compile-only flags (language level, warnings-relevant codegen, include paths).
CXXFLAGS = -std=c++11 -pthread -O2 -fPIC -fno-omit-frame-pointer -I $(HDRS) -I ~/local/onnxruntime/include -I ~/local -I ~/local/include -I .
# Link inputs. Libraries must come AFTER the object files on the link line:
# GNU ld resolves symbols left to right, so the original placement of the
# -l flags inside CXXFLAGS (before $^) could produce undefined references.
LDFLAGS = -L $(LIBS) -L ~/local/onnxruntime/lib -L ~/local/lib
LDLIBS = -lprotobuf -lgflags -lbrpc -lonnxruntime -lutf8proc

BIN = infer_server client
PROTOS = $(wildcard *.proto)
PROTO_OBJS = $(PROTOS:.proto=.pb.o)
MODEL_OBJ = model.o
TOKEN_OBJ = tokenization.o

# ALL and clean produce no files named "ALL"/"clean"; mark them phony so a
# stray file with either name cannot break the build.
.PHONY: ALL clean

ALL: $(BIN)

client: client.cpp $(PROTO_OBJS)
	@echo "> Linking $@"
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) $(LDLIBS)

infer_server: server.cpp $(PROTO_OBJS) $(MODEL_OBJ) $(TOKEN_OBJ)
	@echo "> Linking $@"
	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) $(LDLIBS)

# One protoc invocation emits both the .pb.cc and the .pb.h (grouped
# pattern-rule targets).
%.pb.cc %.pb.h: %.proto
	@echo "> Generating $@"
	$(PROTOC) --cpp_out=. --proto_path=. $<

%.o: %.cc
	@echo "> Compiling $@"
	$(CXX) -c $(CXXFLAGS) $< -o $@

%.o: %.cpp
	@echo "> Compiling $@"
	$(CXX) -c $(CXXFLAGS) $< -o $@

# Tokenizer/model sources live under util/ but their objects are built
# into the top-level directory.
%.o: util/%.cpp
	@echo "> Compiling $@"
	$(CXX) -c $(CXXFLAGS) $< -o $@

clean:
	rm -rf *.o *.pb.* $(BIN)

C++/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# 依赖
2+
- [boost](https://www.boost.org/)
- [utf8proc](https://github.com/guodongxiaren/utf8proc)
3+
强烈不建议使用Github上面的boost项的Release(缺少submodule)
4+
5+
g++ token.cpp -std=c++11 -I ~/local/ -I ~/local/include -L ~/local/lib/ -lutf8proc
6+
export LD_LIBRARY_PATH=~/local/lib:$LD_LIBRARY_PATH
7+
8+
g++ ort_pred.cpp -I ~/local/onnxruntime/include --std=c++11 -L ~/local/onnxruntime/lib -lonnxruntime

C++/a.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#include <iostream>
2+
#include <algorithm>
3+
#include <vector>
4+
#include <chrono>
5+
#include <string>
6+
#include <vector>
7+
#include <onnxruntime_cxx_api.h>
8+
#include "util/tokenization.h"
9+
10+
using namespace std;
11+
12+
// Category labels for the news classifier; index i corresponds to
// logit i of the model's output tensor.
const static std::vector<std::string> key = {
    "finance",
    "realty",
    "stocks",
    "education",
    "science",
    "society",
    "politics",
    "sports",
    "game",
    "entertainment"
};

// Returns the index of the largest element of v, or -1 if v is empty.
template <typename T>
int argmax(const std::vector<T>& v) {
    if (v.empty()) {
        return -1;
    }
    // max_element yields an iterator; the difference is a ptrdiff_t, so
    // narrow explicitly to int (label counts here are tiny).
    return static_cast<int>(std::max_element(v.begin(), v.end()) - v.begin());
}

// Iterator-range overload: index of the largest element in [a, b).
// Fix: the original returned 0 for an empty range, which is
// indistinguishable from "first element is the max"; return -1 instead,
// consistent with the vector overload.
template <typename T>
int argmax(T a, T b) {
    if (a == b) {
        return -1;
    }
    return static_cast<int>(std::max_element(a, b) - a);
}
36+
class Model {
37+
public:
38+
Model(const std::string& vocab_path) {
39+
tokenizer_ = new FullTokenizer(vocab_path);
40+
}
41+
42+
std::vector<std::vector<int64_t>> build_input(const std::string& text) {
43+
auto tokens = tokenizer_->tokenize(text);
44+
auto token_ids = tokenizer_->convertTokensToIds(tokens);
45+
46+
std::vector<std::vector<int64_t>> res;
47+
48+
std::vector<int64_t> input(32);
49+
std::vector<int64_t> mask(32);
50+
input[0] = 101;
51+
mask[0] = 1;
52+
for (int i = 0; i < token_ids.size() && i < 31; ++i) {
53+
input[i+1] = token_ids[i];
54+
mask[i+1] = token_ids[i] > 0;
55+
}
56+
res.push_back(std::move(input));
57+
res.push_back(std::move(mask));
58+
return res;
59+
}
60+
FullTokenizer* tokenizer_ = nullptr;
61+
};
62+
63+
// Standalone smoke test: tokenize one hard-coded Chinese headline, run it
// through the ONNX BERT classifier with the CUDA execution provider, and
// print the raw logits plus the predicted category label.
int main()
{
    const char* text = "李稻葵:过去2年抗疫为每人增寿10天";
    // NOTE(review): both paths below are hard-coded to the author's machine.
    const char* vocab_path = "/home/guodong/bert_pretrain/vocab.txt";
    Model model(vocab_path);
    // res[0] = token ids, res[1] = attention mask (both length 32).
    auto res = model.build_input(text);

    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
    Ort::SessionOptions session_options;

    // Default-constructed CUDA options (device 0); the aggregate-init
    // alternative is kept below for reference.
    OrtCUDAProviderOptions cuda_options; //= {
    //     0,
    //     //OrtCudnnConvAlgoSearch::EXHAUSTIVE,
    //     OrtCudnnConvAlgoSearchExhaustive,
    //     std::numeric_limits<size_t>::max(),
    //     0,
    //     true
    //     };

    session_options.AppendExecutionProvider_CUDA(cuda_options);
    const char* model_path = "/home/guodong/github/Bert-Chinese-Text-Classification-Pytorch/model.onnx";

    Ort::Session session(env, model_path, session_options);
    // print model input layer (node names, types, shape etc.)
    Ort::AllocatorWithDefaultOptions allocator;

    // Sanity-print the number of model inputs and outputs.
    size_t num_input_nodes = session.GetInputCount();
    std::cout<< num_input_nodes <<std::endl;
    std::cout<< session.GetOutputCount() <<std::endl;

    // Both tensors are shaped [batch=1, seq_len=32].
    std::vector<int64_t> input_node_dims = {1, 32};

    auto& input_tensor_values = res[0];
    auto& mask_tensor_values = res[1];

    //size_t input_tensor_size = 32;
    // Debug dump of the ids and mask actually fed to the model.
    for (auto i : input_tensor_values) {
        std::cout << i << "\t" ;
    }
    std::cout<<std::endl;
    for (auto i : mask_tensor_values) {
        std::cout << i << "\t" ;
    }
    std::cout<<std::endl;

    // Create input tensors as non-owning views over the vectors above, so
    // the vectors must outlive the Run() call (they do — same scope).
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

    Ort::Value input_tensor = Ort::Value::CreateTensor<int64_t>(memory_info, input_tensor_values.data(),
        input_tensor_values.size(), input_node_dims.data(), 2);

    Ort::Value mask_tensor = Ort::Value::CreateTensor<int64_t>(memory_info, mask_tensor_values.data(),
        mask_tensor_values.size(), input_node_dims.data(), 2);

    std::vector<Ort::Value> ort_inputs;
    ort_inputs.push_back(std::move(input_tensor));
    ort_inputs.push_back(std::move(mask_tensor));

    // Node names must match those baked into the exported ONNX graph.
    std::vector<const char*> input_node_names = {"ids", "mask"};
    std::vector<const char*> output_node_names = {"output"};
    auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), ort_inputs.data(),
        ort_inputs.size(), output_node_names.data(), 1);

    // Raw logits; pointer is owned by output_tensors[0].
    float* floatarr = output_tensors[0].GetTensorMutableData<float>();

    // 10 logits — one per entry of `key`.
    for (int i=0; i<10; i++)
    {
        std::cout<<floatarr[i]<<std::endl;
    }
    std::cout<< key[argmax(floatarr, floatarr+10)] << std::endl;

    return 0;
}

C++/client.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//
2+
#include <gflags/gflags.h>
3+
#include <butil/logging.h>
4+
#include <butil/time.h>
5+
#include <brpc/channel.h>
6+
#include "infer.pb.h"
7+
8+
// Command-line flags controlling the RPC client (parsed via gflags in main).
DEFINE_string(attachment, "", "Carry this along with requests");
DEFINE_string(protocol, "baidu_std", "Protocol type. Defined in src/brpc/options.proto");
DEFINE_string(connection_type, "", "Connection type. Available values: single, pooled, short");
DEFINE_string(server, "0.0.0.0:8000", "IP Address of server");
DEFINE_string(load_balancer, "", "The algorithm for load balancing");
DEFINE_int32(timeout_ms, 100, "RPC timeout in milliseconds");
DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)");
DEFINE_int32(interval_ms, 1000, "Milliseconds between consecutive requests");
16+
17+
// brpc client: repeatedly sends one hard-coded news title to the
// InferService and logs the predicted category, until asked to quit.
int main(int argc, char* argv[]) {
    // Parse gflags. We recommend you to use gflags as well.
    //GFLAGS_NS::ParseCommandLineFlags(&argc, &argv, true);
    gflags::ParseCommandLineFlags(&argc, &argv, true);

    // A Channel represents a communication line to a Server. Notice that
    // Channel is thread-safe and can be shared by all threads in your program.
    brpc::Channel channel;

    // Initialize the channel, NULL means using default options.
    brpc::ChannelOptions options;
    options.protocol = FLAGS_protocol;
    options.connection_type = FLAGS_connection_type;
    options.timeout_ms = FLAGS_timeout_ms/*milliseconds*/;
    options.max_retry = FLAGS_max_retry;
    if (channel.Init(FLAGS_server.c_str(), FLAGS_load_balancer.c_str(), &options) != 0) {
        LOG(ERROR) << "Fail to initialize channel";
        return -1;
    }

    // Normally, you should not call a Channel directly, but instead construct
    // a stub Service wrapping it. stub can be shared by all threads as well.
    guodongxiaren::InferService_Stub stub(&channel);

    // Send a request and wait for the response every 1 second
    // (interval configurable via --interval_ms).
    int log_id = 0;
    while (!brpc::IsAskedToQuit()) {
        // We will receive response synchronously, safe to put variables
        // on stack.
        guodongxiaren::NewsClassifyRequest request;
        guodongxiaren::NewsClassifyResponse response;
        brpc::Controller cntl;

        request.set_title("李稻葵:过去2年抗疫为每人增寿10天");

        cntl.set_log_id(log_id ++); // set by user

        // Set attachment which is wired to network directly instead of
        // being serialized into protobuf messages.
        cntl.request_attachment().append(FLAGS_attachment);

        // Because `done'(last parameter) is NULL, this function waits until
        // the response comes back or error occurs(including timedout).
        stub.NewsClassify(&cntl, &request, &response, NULL);
        if (!cntl.Failed()) {
            LOG(INFO) << "Received response from " << cntl.remote_side()
                << " to " << cntl.local_side()
                << ": " << response.result() << " (attached="
                << cntl.response_attachment() << ")"
                << " latency=" << cntl.latency_us() << "us";
        } else {
            LOG(WARNING) << cntl.ErrorText();
        }
        usleep(FLAGS_interval_ms * 1000L);
    }

    LOG(INFO) << "NewsClassifyClient is going to quit";
    return 0;
}

C++/infer.proto

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RPC contract for the news-title classification service.
syntax="proto2";
package guodongxiaren;

// Generate abstract service base classes (required by brpc servers).
option cc_generic_services = true;

message NewsClassifyRequest {
    // News headline to classify.
    required string title = 1;
};

message NewsClassifyResponse {
    // Predicted category label (e.g. "sports", "finance").
    required string result = 1;
};

service InferService {
    rpc NewsClassify(NewsClassifyRequest) returns (NewsClassifyResponse);
};

C++/ort_pred.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <iostream>
2+
#include <algorithm>
3+
#include <vector>
4+
#include <chrono>
5+
#include <string>
6+
#include <vector>
7+
#include "util/model.h"
8+
9+
using namespace std;
10+
11+
// Interactive CLI: reads one news title per line from stdin, classifies it
// with the ONNX BERT model wrapped by util/model.h, and prints the predicted
// label together with the per-request latency.
int main() {
    // NOTE(review): paths are hard-coded to the author's machine.
    const char* vocab_path = "/home/guodong/bert_pretrain/vocab.txt";
    const char* model_path = "/home/guodong/github/Bert-Chinese-Text-Classification-Pytorch/model.onnx";

    Model model(model_path, vocab_path);

    //const char* text = "李稻葵:过去2年抗疫为每人增寿10天";
    //int idx = model.predict(text);

    std::string line;
    while (std::getline(std::cin, line)) {
        // gettimeofday_us: presumably a microsecond wall-clock helper
        // declared in util/model.h — TODO confirm.
        auto a = gettimeofday_us();
        std::string r = model.predict(line);
        auto b = gettimeofday_us();
        std::cout << line << " is " << r << " cost:" << (b-a) <<" us" <<std::endl;
    }

    return 0;
}

0 commit comments

Comments
 (0)