|
1 |
| -#include "common.h" |
2 |
| -#include "common-whisper.h" |
3 | 1 | #include "whisper.h"
|
4 | 2 | #include "ixwebsocket/IXWebSocketServer.h"
|
5 |
| -#include "ixwebsocket/IXNetSystem.h" |
6 |
| - |
7 | 3 | #include <atomic>
|
8 |
| -#include <chrono> |
9 |
| -#include <condition_variable> |
10 |
| -#include <cstdio> |
11 |
| -#include <fstream> |
12 | 4 | #include <mutex>
|
13 |
| -#include <string> |
14 |
| -#include <thread> |
15 |
| -#include <unordered_map> |
16 |
| -#include <vector> |
| 5 | +#include <queue> |
| 6 | + |
| 7 | +std::mutex g_ctx_mtx; |
| 8 | +whisper_context* g_ctx = nullptr; |
| 9 | +constexpr int CHUNK_SIZE = 3 * 16000; |
17 | 10 |
|
18 | 11 | struct ClientSession {
|
19 | 12 | std::vector<float> pcm_buffer;
|
20 | 13 | std::mutex mtx;
|
21 |
| - std::condition_variable cv; |
22 |
| - std::atomic<bool> active{false}; |
23 |
| - std::atomic<bool> terminate{false}; |
24 |
| -}; |
25 |
| - |
26 |
| -struct server_params { |
27 |
| - int32_t port = 9002; |
28 |
| - int32_t n_threads = std::min(4, (int32_t)std::thread::hardware_concurrency()); |
29 |
| - std::string model = "models/ggml-base.en.bin"; |
30 |
| - bool use_gpu = true; |
| 14 | + std::atomic<bool> active{true}; |
31 | 15 | };
|
32 | 16 |
|
33 | 17 | class WhisperServer {
|
34 | 18 | private:
|
35 |
| - server_params params; |
36 | 19 | ix::WebSocketServer server;
|
37 | 20 | std::unordered_map<std::string, std::unique_ptr<ClientSession>> clients;
|
38 | 21 | std::mutex clients_mtx;
|
39 |
| - whisper_context* ctx = nullptr; |
| 22 | + std::thread processor_thread; |
| 23 | + std::atomic<bool> running{true}; |
40 | 24 |
|
41 | 25 | public:
|
42 |
| - WhisperServer(const server_params& params) : params(params), server(params.port, "0.0.0.0") { |
43 |
| - ix::initNetSystem(); |
44 |
| - |
| 26 | + WhisperServer(int port, const std::string& model_path) : server(port, "0.0.0.0") { |
45 | 27 | whisper_context_params cparams = whisper_context_default_params();
|
46 |
| - cparams.use_gpu = params.use_gpu; |
47 |
| - ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); |
| 28 | + cparams.use_gpu = true; |
| 29 | + g_ctx = whisper_init_from_file(model_path.c_str()); |
48 | 30 |
|
49 | 31 | server.setTLSOptions({});
|
50 |
| - server.setOnClientMessageCallback([this](std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket& webSocket, const ix::WebSocketMessagePtr& msg) { |
51 |
| - this->handleMessage(connectionState, webSocket, msg); |
| 32 | + server.setOnClientMessageCallback([this](auto&&... args) { |
| 33 | + handleMessage(args...); |
52 | 34 | });
|
| 35 | + |
| 36 | + processor_thread = std::thread([this] { processQueues(); }); |
53 | 37 | }
|
54 | 38 |
|
55 | 39 | ~WhisperServer() {
|
| 40 | + running = false; |
56 | 41 | server.stop();
|
57 |
| - ix::uninitNetSystem(); |
58 |
| - if(ctx) whisper_free(ctx); |
| 42 | + if (processor_thread.joinable()) processor_thread.join(); |
| 43 | + whisper_free(g_ctx); |
59 | 44 | }
|
60 | 45 |
|
61 | 46 | void run() {
|
62 | 47 | server.listenAndStart();
|
63 |
| - fprintf(stderr, "Server started on port %d\n", params.port); |
64 |
| - |
65 |
| - while(true) { |
66 |
| - std::this_thread::sleep_for(std::chrono::seconds(1)); |
67 |
| - } |
| 48 | + while (running) std::this_thread::sleep_for(std::chrono::seconds(1)); |
68 | 49 | }
|
69 | 50 |
|
70 | 51 | private:
|
71 |
| - void handleMessage(std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket& webSocket, const ix::WebSocketMessagePtr& msg) { |
72 |
| - const auto client_id = connectionState->getId(); |
| 52 | + void handleMessage(std::shared_ptr<ix::ConnectionState> state, |
| 53 | + ix::WebSocket& ws, |
| 54 | + const ix::WebSocketMessagePtr& msg) { |
| 55 | + const std::string client_id = state->getId(); |
73 | 56 |
|
74 |
| - if(msg->type == ix::WebSocketMessageType::Open) { |
75 |
| - fprintf(stderr, "New client connected: %s\n", client_id); |
| 57 | + if (msg->type == ix::WebSocketMessageType::Open) { |
76 | 58 | std::lock_guard<std::mutex> lock(clients_mtx);
|
77 | 59 | clients[client_id] = std::make_unique<ClientSession>();
|
78 |
| - clients[client_id]->active = true; |
79 |
| - std::thread(&WhisperServer::processClientAudio, this, client_id).detach(); |
80 | 60 | }
|
81 |
| - else if(msg->type == ix::WebSocketMessageType::Close) { |
82 |
| - fprintf(stderr, "Client disconnected: %s\n", client_id); |
| 61 | + else if (msg->type == ix::WebSocketMessageType::Close) { |
83 | 62 | std::lock_guard<std::mutex> lock(clients_mtx);
|
84 |
| - if(clients.count(client_id)) { |
85 |
| - clients[client_id]->terminate = true; |
86 |
| - clients[client_id]->cv.notify_one(); |
| 63 | + if (clients.count(client_id)) { |
| 64 | + clients[client_id]->active = false; |
87 | 65 | clients.erase(client_id);
|
88 | 66 | }
|
89 | 67 | }
|
90 |
| - else if(msg->type == ix::WebSocketMessageType::Message) { |
91 |
| - //std::lock_guard<std::mutex> lock(clients_mtx); |
| 68 | + else if (msg->type == ix::WebSocketMessageType::Message && msg->binary) { |
| 69 | + std::lock_guard<std::mutex> lock(clients_mtx); |
| 70 | + if (!clients.count(client_id)) return; |
92 | 71 |
|
93 |
| - if(auto it = clients.find(client_id); it != clients.end()) { |
94 |
| - auto& session = *it->second; |
95 |
| - std::lock_guard<std::mutex> session_lock(session.mtx); |
96 |
| - |
97 |
| - if (!msg->binary) { |
98 |
| - webSocket.sendText("Error: Expected binary data"); |
99 |
| - fprintf(stderr, "Client %s sent text data\n", client_id.c_str()); |
100 |
| - return; |
101 |
| - } |
| 72 | + auto& session = *clients[client_id]; |
| 73 | + const auto& data = msg->str; |
| 74 | + const int16_t* pcm16 = reinterpret_cast<const int16_t*>(data.data()); |
| 75 | + size_t n_samples = data.size() / sizeof(int16_t); |
102 | 76 |
|
103 |
| - const auto &data = msg->str; |
104 |
| - size_t data_size = data.size(); |
105 |
| - |
106 |
| - if (data_size % sizeof(int16_t) != 0) { |
107 |
| - webSocket.sendText("Error: Invalid data size"); |
108 |
| - fprintf(stderr, "Invalid data size from %s: %zu\n", client_id.c_str(), data_size); |
109 |
| - return; |
110 |
| - } |
111 |
| - //PCM16 -> FLOAT32 |
112 |
| - const int16_t* pcm16 = reinterpret_cast<const int16_t*>(data.data()); |
113 |
| - const size_t num_samples = data_size / sizeof(int16_t); |
114 |
| - |
115 |
| - session.pcm_buffer.reserve(session.pcm_buffer.size() + num_samples); |
116 |
| - for(size_t i = 0; i < num_samples; ++i) { |
117 |
| - session.pcm_buffer.push_back(pcm16[i] / 32768.0f); |
118 |
| - } |
119 |
| - |
120 |
| - session.cv.notify_one(); |
| 77 | + std::lock_guard<std::mutex> session_lock(session.mtx); |
| 78 | + for (size_t i = 0; i < n_samples; i++) { |
| 79 | + session.pcm_buffer.push_back(pcm16[i] / 32768.0f); |
121 | 80 | }
|
122 | 81 | }
|
123 | 82 | }
|
124 | 83 |
|
125 |
| - void processClientAudio(std::string client_id) { |
126 |
| - constexpr int step_ms = 300; |
127 |
| - constexpr int n_samples_step = (1e-3 * step_ms) * WHISPER_SAMPLE_RATE; |
128 |
| - |
129 |
| - fprintf(stderr, "Started thread for: %s\n", client_id); |
130 |
| - while(true) { |
131 |
| - std::vector<float> audio_chunk; |
132 |
| - { |
133 |
| - |
134 |
| - //fprintf(stderr, "Started read chunk from: %s\n", client_id); |
135 |
| - std::unique_lock<std::mutex> lock(clients_mtx); |
136 |
| - if(!clients.count(client_id)) break; |
137 |
| - auto& session = *clients[client_id]; |
138 |
| - |
139 |
| - std::unique_lock<std::mutex> session_lock(session.mtx); |
140 |
| - session.cv.wait_for(session_lock, std::chrono::milliseconds(100), [&session] { |
141 |
| - return session.pcm_buffer.size() >= n_samples_step || session.terminate; |
142 |
| - }); |
143 |
| - |
144 |
| - if(session.terminate) break; |
145 |
| - |
146 |
| - if(session.pcm_buffer.size() >= n_samples_step) { |
147 |
| - audio_chunk.assign(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step); |
148 |
| - session.pcm_buffer.erase(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step); |
149 |
| - } |
150 |
| - |
151 |
| - size_t available = session.pcm_buffer.size(); |
152 |
| - if(available >= n_samples_step) { |
153 |
| - size_t take = std::min(available, (size_t)n_samples_step); |
154 |
| - audio_chunk.assign( |
155 |
| - session.pcm_buffer.begin(), |
156 |
| - session.pcm_buffer.begin() + take |
157 |
| - ); |
158 |
| - session.pcm_buffer.erase( |
159 |
| - session.pcm_buffer.begin(), |
160 |
| - session.pcm_buffer.begin() + take |
161 |
| - ); |
162 |
| - } |
163 |
| - //fprintf(stderr, "End of read chunk: %s\n", client_id); |
164 |
| - } |
| 84 | + void processQueues() { |
| 85 | + while (running) { |
| 86 | + std::this_thread::sleep_for(std::chrono::milliseconds(100)); |
165 | 87 |
|
166 |
| - if(!audio_chunk.empty()) { |
167 |
| - |
168 |
| - //fprintf(stderr, "Good, chunk not empty for: %s\n", client_id); |
169 |
| - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH); |
170 |
| - wparams.n_threads = params.n_threads; |
171 |
| - wparams.language = "en"; |
172 |
| - wparams.print_realtime = false; |
173 |
| - wparams.print_progress = false; |
174 |
| - wparams.single_segment = true; |
175 |
| - wparams.max_tokens = 32; |
176 |
| - |
177 |
| - if(whisper_full(ctx, wparams, audio_chunk.data(), audio_chunk.size()) == 0) { |
178 |
| - |
179 |
| - fprintf(stderr, "whisper_full == 0: %s\n", client_id); |
180 |
| - const int n_segments = whisper_full_n_segments(ctx); |
181 |
| - for(int i = 0; i < n_segments; ++i) { |
182 |
| - const char* text = whisper_full_get_segment_text(ctx, i); |
183 |
| - fprintf(stdout, "[Client %s] %s\n", client_id, text); |
| 88 | + std::lock_guard<std::mutex> lock(clients_mtx); |
| 89 | + for (auto& [id, session] : clients) { |
| 90 | + std::lock_guard<std::mutex> session_lock(session->mtx); |
| 91 | + if (session->pcm_buffer.size() < CHUNK_SIZE) continue; |
| 92 | + |
| 93 | + std::vector<float> chunk( |
| 94 | + session->pcm_buffer.begin(), |
| 95 | + session->pcm_buffer.begin() + CHUNK_SIZE |
| 96 | + ); |
| 97 | + session->pcm_buffer.erase( |
| 98 | + session->pcm_buffer.begin(), |
| 99 | + session->pcm_buffer.begin() + CHUNK_SIZE |
| 100 | + ); |
| 101 | + |
| 102 | + { |
| 103 | + std::lock_guard<std::mutex> ctx_lock(g_ctx_mtx); |
| 104 | + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); |
| 105 | + wparams.print_progress = false; |
| 106 | + wparams.print_realtime = false; |
| 107 | + wparams.single_segment = true; |
| 108 | + |
| 109 | + if (whisper_full(g_ctx, wparams, chunk.data(), chunk.size()) == 0) { |
| 110 | + const char* text = whisper_full_get_segment_text(g_ctx, 0); |
| 111 | + printf("[%s] %s\n", id.c_str(), text); |
184 | 112 | }
|
| 113 | + whisper_reset_timings(g_ctx); |
185 | 114 | }
|
186 | 115 | }
|
187 | 116 | }
|
188 | 117 | }
|
189 | 118 | };
|
190 | 119 |
|
191 | 120 | int main(int argc, char** argv) {
|
192 |
| - server_params params; |
193 |
| - params.port = 9002; |
194 |
| - params.model = "ggml-large-v3-turbo.bin"; |
| 121 | + if (argc < 3) { |
| 122 | + //fprintf(stderr, "Usage: %s <port> <model_path>\n", argv[0]); |
| 123 | + |
| 124 | + WhisperServer server(9002, "ggml-large-v3-turbo.bin"); |
| 125 | + server.run(); |
| 126 | + return 0; |
| 127 | + } |
195 | 128 |
|
196 |
| - WhisperServer server(params); |
| 129 | + WhisperServer server(atoi(argv[1]), argv[2]); |
197 | 130 | server.run();
|
198 | 131 | return 0;
|
199 |
| -} |
| 132 | +} |
| 133 | + |
0 commit comments