Skip to content

Commit 890a141

Browse files
Update online-websocket-server.cc
1 parent 337d5f7 commit 890a141

File tree

1 file changed

+65
-53
lines changed

1 file changed

+65
-53
lines changed

sherpa-onnx/csrc/online-websocket-server.cc

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -30,80 +30,92 @@ Please refer to
3030
for a list of pre-trained models to download.
3131
)";
3232

33-
int32_t main(int32_t argc, char *argv[]) {
34-
sherpa_onnx::ParseOptions po(kUsageMessage);
33+
class OnlineWebsocketServerApp {
34+
public:
35+
OnlineWebsocketServerApp(int32_t argc, char *argv[]) : argc_(argc), argv_(argv) {}
36+
37+
void Run() {
38+
sherpa_onnx::ParseOptions po(kUsageMessage);
39+
sherpa_onnx::OnlineWebsocketServerConfig config;
3540

36-
sherpa_onnx::OnlineWebsocketServerConfig config;
41+
// the server will listen on this port
42+
int32_t port = 6006;
3743

38-
// the server will listen on this port
39-
int32_t port = 6006;
44+
// size of the thread pool for handling network connections
45+
int32_t num_io_threads = 1;
4046

41-
// size of the thread pool for handling network connections
42-
int32_t num_io_threads = 1;
47+
// size of the thread pool for neural network computation and decoding
48+
int32_t num_work_threads = 3;
4349

44-
// size of the thread pool for neural network computation and decoding
45-
int32_t num_work_threads = 3;
50+
po.Register("num-io-threads", &num_io_threads,
51+
"Thread pool size for network connections.");
4652

47-
po.Register("num-io-threads", &num_io_threads,
48-
"Thread pool size for network connections.");
53+
po.Register("num-work-threads", &num_work_threads,
54+
"Thread pool size for for neural network "
55+
"computation and decoding.");
4956

50-
po.Register("num-work-threads", &num_work_threads,
51-
"Thread pool size for for neural network "
52-
"computation and decoding.");
57+
po.Register("port", &port, "The port on which the server will listen.");
5358

54-
po.Register("port", &port, "The port on which the server will listen.");
59+
config.Register(&po);
5560

56-
config.Register(&po);
61+
if (argc_ == 1) {
62+
po.PrintUsage();
63+
exit(EXIT_FAILURE);
64+
}
5765

58-
if (argc == 1) {
59-
po.PrintUsage();
60-
exit(EXIT_FAILURE);
61-
}
66+
po.Read(argc_, argv_);
6267

63-
po.Read(argc, argv);
68+
if (po.NumArgs() != 0) {
69+
SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
70+
po.PrintUsage();
71+
exit(EXIT_FAILURE);
72+
}
6473

65-
if (po.NumArgs() != 0) {
66-
SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
67-
po.PrintUsage();
68-
exit(EXIT_FAILURE);
69-
}
74+
config.Validate();
7075

71-
config.Validate();
76+
asio::io_context io_conn; // for network connections
77+
asio::io_context io_work; // for neural network and decoding
7278

73-
asio::io_context io_conn; // for network connections
74-
asio::io_context io_work; // for neural network and decoding
79+
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
80+
server.Run(port);
7581

76-
sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
77-
server.Run(port);
82+
SHERPA_ONNX_LOGE("Started!");
83+
SHERPA_ONNX_LOGE("Listening on: %d", port);
84+
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
7885

79-
SHERPA_ONNX_LOGE("Started!");
80-
SHERPA_ONNX_LOGE("Listening on: %d", port);
81-
SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
86+
// give some work to do for the io_work pool
87+
auto work_guard = asio::make_work_guard(io_work);
8288

83-
// give some work to do for the io_work pool
84-
auto work_guard = asio::make_work_guard(io_work);
89+
std::vector<std::thread> io_threads;
8590

86-
std::vector<std::thread> io_threads;
91+
// decrement since the main thread is also used for network communications
92+
for (int32_t i = 0; i < num_io_threads - 1; ++i) {
93+
io_threads.emplace_back([&io_conn]() { io_conn.run(); });
94+
}
8795

88-
// decrement since the main thread is also used for network communications
89-
for (int32_t i = 0; i < num_io_threads - 1; ++i) {
90-
io_threads.emplace_back([&io_conn]() { io_conn.run(); });
91-
}
96+
std::vector<std::thread> work_threads;
97+
for (int32_t i = 0; i < num_work_threads; ++i) {
98+
work_threads.emplace_back([&io_work]() { io_work.run(); });
99+
}
92100

93-
std::vector<std::thread> work_threads;
94-
for (int32_t i = 0; i < num_work_threads; ++i) {
95-
work_threads.emplace_back([&io_work]() { io_work.run(); });
96-
}
101+
io_conn.run();
97102

98-
io_conn.run();
103+
for (auto &t : io_threads) {
104+
t.join();
105+
}
99106

100-
for (auto &t : io_threads) {
101-
t.join();
102-
}
107+
for (auto &t : work_threads) {
108+
t.join();
109+
}
110+
}
103111

104-
for (auto &t : work_threads) {
105-
t.join();
106-
}
112+
private:
113+
int32_t argc_;
114+
char **argv_;
115+
};
107116

108-
return 0;
117+
int32_t main(int32_t argc, char *argv[]) {
118+
OnlineWebsocketServerApp app(argc, argv);
119+
app.Run();
120+
return 0;
109121
}

0 commit comments

Comments
 (0)