Skip to content

Commit 6d00648

Browse files
author
lexasub
committed
third try
1 parent 98c916b commit 6d00648

File tree

1 file changed

+75
-141
lines changed

1 file changed

+75
-141
lines changed

examples/websocket-stream/stream.cpp

Lines changed: 75 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -1,199 +1,133 @@
#include "common.h"
#include "common-whisper.h"
#include "whisper.h"
#include "ixwebsocket/IXWebSocketServer.h"
#include "ixwebsocket/IXNetSystem.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <fstream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

std::mutex g_ctx_mtx;
8+
whisper_context* g_ctx = nullptr;
9+
constexpr int CHUNK_SIZE = 3 * 16000;
1710

1811
/// Per-connection state for one websocket client.
struct ClientSession {
    /// Audio received so far and not yet transcribed, as float samples.
    std::vector<float> pcm_buffer;
    /// Protects pcm_buffer; taken by both the message handler and the
    /// processing loop.
    std::mutex mtx;
    /// True from construction; cleared when the client's Close event arrives.
    std::atomic<bool> active{true};
};
3216

3317
class WhisperServer {
3418
private:
35-
server_params params;
3619
ix::WebSocketServer server;
3720
std::unordered_map<std::string, std::unique_ptr<ClientSession>> clients;
3821
std::mutex clients_mtx;
39-
whisper_context* ctx = nullptr;
22+
std::thread processor_thread;
23+
std::atomic<bool> running{true};
4024

4125
public:
42-
WhisperServer(const server_params& params) : params(params), server(params.port, "0.0.0.0") {
43-
ix::initNetSystem();
44-
26+
WhisperServer(int port, const std::string& model_path) : server(port, "0.0.0.0") {
4527
whisper_context_params cparams = whisper_context_default_params();
46-
cparams.use_gpu = params.use_gpu;
47-
ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
28+
cparams.use_gpu = true;
29+
g_ctx = whisper_init_from_file(model_path.c_str());
4830

4931
server.setTLSOptions({});
50-
server.setOnClientMessageCallback([this](std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket& webSocket, const ix::WebSocketMessagePtr& msg) {
51-
this->handleMessage(connectionState, webSocket, msg);
32+
server.setOnClientMessageCallback([this](auto&&... args) {
33+
handleMessage(args...);
5234
});
35+
36+
processor_thread = std::thread([this] { processQueues(); });
5337
}
5438

5539
~WhisperServer() {
40+
running = false;
5641
server.stop();
57-
ix::uninitNetSystem();
58-
if(ctx) whisper_free(ctx);
42+
if (processor_thread.joinable()) processor_thread.join();
43+
whisper_free(g_ctx);
5944
}
6045

6146
void run() {
6247
server.listenAndStart();
63-
fprintf(stderr, "Server started on port %d\n", params.port);
64-
65-
while(true) {
66-
std::this_thread::sleep_for(std::chrono::seconds(1));
67-
}
48+
while (running) std::this_thread::sleep_for(std::chrono::seconds(1));
6849
}
6950

7051
private:
71-
void handleMessage(std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket& webSocket, const ix::WebSocketMessagePtr& msg) {
72-
const auto client_id = connectionState->getId();
52+
void handleMessage(std::shared_ptr<ix::ConnectionState> state,
53+
ix::WebSocket& ws,
54+
const ix::WebSocketMessagePtr& msg) {
55+
const std::string client_id = state->getId();
7356

74-
if(msg->type == ix::WebSocketMessageType::Open) {
75-
fprintf(stderr, "New client connected: %s\n", client_id);
57+
if (msg->type == ix::WebSocketMessageType::Open) {
7658
std::lock_guard<std::mutex> lock(clients_mtx);
7759
clients[client_id] = std::make_unique<ClientSession>();
78-
clients[client_id]->active = true;
79-
std::thread(&WhisperServer::processClientAudio, this, client_id).detach();
8060
}
81-
else if(msg->type == ix::WebSocketMessageType::Close) {
82-
fprintf(stderr, "Client disconnected: %s\n", client_id);
61+
else if (msg->type == ix::WebSocketMessageType::Close) {
8362
std::lock_guard<std::mutex> lock(clients_mtx);
84-
if(clients.count(client_id)) {
85-
clients[client_id]->terminate = true;
86-
clients[client_id]->cv.notify_one();
63+
if (clients.count(client_id)) {
64+
clients[client_id]->active = false;
8765
clients.erase(client_id);
8866
}
8967
}
90-
else if(msg->type == ix::WebSocketMessageType::Message) {
91-
//std::lock_guard<std::mutex> lock(clients_mtx);
68+
else if (msg->type == ix::WebSocketMessageType::Message && msg->binary) {
69+
std::lock_guard<std::mutex> lock(clients_mtx);
70+
if (!clients.count(client_id)) return;
9271

93-
if(auto it = clients.find(client_id); it != clients.end()) {
94-
auto& session = *it->second;
95-
std::lock_guard<std::mutex> session_lock(session.mtx);
96-
97-
if (!msg->binary) {
98-
webSocket.sendText("Error: Expected binary data");
99-
fprintf(stderr, "Client %s sent text data\n", client_id.c_str());
100-
return;
101-
}
72+
auto& session = *clients[client_id];
73+
const auto& data = msg->str;
74+
const int16_t* pcm16 = reinterpret_cast<const int16_t*>(data.data());
75+
size_t n_samples = data.size() / sizeof(int16_t);
10276

103-
const auto &data = msg->str;
104-
size_t data_size = data.size();
105-
106-
if (data_size % sizeof(int16_t) != 0) {
107-
webSocket.sendText("Error: Invalid data size");
108-
fprintf(stderr, "Invalid data size from %s: %zu\n", client_id.c_str(), data_size);
109-
return;
110-
}
111-
//PCM16 -> FLOAT32
112-
const int16_t* pcm16 = reinterpret_cast<const int16_t*>(data.data());
113-
const size_t num_samples = data_size / sizeof(int16_t);
114-
115-
session.pcm_buffer.reserve(session.pcm_buffer.size() + num_samples);
116-
for(size_t i = 0; i < num_samples; ++i) {
117-
session.pcm_buffer.push_back(pcm16[i] / 32768.0f);
118-
}
119-
120-
session.cv.notify_one();
77+
std::lock_guard<std::mutex> session_lock(session.mtx);
78+
for (size_t i = 0; i < n_samples; i++) {
79+
session.pcm_buffer.push_back(pcm16[i] / 32768.0f);
12180
}
12281
}
12382
}
12483

125-
void processClientAudio(std::string client_id) {
126-
constexpr int step_ms = 300;
127-
constexpr int n_samples_step = (1e-3 * step_ms) * WHISPER_SAMPLE_RATE;
128-
129-
fprintf(stderr, "Started thread for: %s\n", client_id);
130-
while(true) {
131-
std::vector<float> audio_chunk;
132-
{
133-
134-
//fprintf(stderr, "Started read chunk from: %s\n", client_id);
135-
std::unique_lock<std::mutex> lock(clients_mtx);
136-
if(!clients.count(client_id)) break;
137-
auto& session = *clients[client_id];
138-
139-
std::unique_lock<std::mutex> session_lock(session.mtx);
140-
session.cv.wait_for(session_lock, std::chrono::milliseconds(100), [&session] {
141-
return session.pcm_buffer.size() >= n_samples_step || session.terminate;
142-
});
143-
144-
if(session.terminate) break;
145-
146-
if(session.pcm_buffer.size() >= n_samples_step) {
147-
audio_chunk.assign(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step);
148-
session.pcm_buffer.erase(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step);
149-
}
150-
151-
size_t available = session.pcm_buffer.size();
152-
if(available >= n_samples_step) {
153-
size_t take = std::min(available, (size_t)n_samples_step);
154-
audio_chunk.assign(
155-
session.pcm_buffer.begin(),
156-
session.pcm_buffer.begin() + take
157-
);
158-
session.pcm_buffer.erase(
159-
session.pcm_buffer.begin(),
160-
session.pcm_buffer.begin() + take
161-
);
162-
}
163-
//fprintf(stderr, "End of read chunk: %s\n", client_id);
164-
}
84+
void processQueues() {
85+
while (running) {
86+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
16587

166-
if(!audio_chunk.empty()) {
167-
168-
//fprintf(stderr, "Good, chunk not empty for: %s\n", client_id);
169-
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
170-
wparams.n_threads = params.n_threads;
171-
wparams.language = "en";
172-
wparams.print_realtime = false;
173-
wparams.print_progress = false;
174-
wparams.single_segment = true;
175-
wparams.max_tokens = 32;
176-
177-
if(whisper_full(ctx, wparams, audio_chunk.data(), audio_chunk.size()) == 0) {
178-
179-
fprintf(stderr, "whisper_full == 0: %s\n", client_id);
180-
const int n_segments = whisper_full_n_segments(ctx);
181-
for(int i = 0; i < n_segments; ++i) {
182-
const char* text = whisper_full_get_segment_text(ctx, i);
183-
fprintf(stdout, "[Client %s] %s\n", client_id, text);
88+
std::lock_guard<std::mutex> lock(clients_mtx);
89+
for (auto& [id, session] : clients) {
90+
std::lock_guard<std::mutex> session_lock(session->mtx);
91+
if (session->pcm_buffer.size() < CHUNK_SIZE) continue;
92+
93+
std::vector<float> chunk(
94+
session->pcm_buffer.begin(),
95+
session->pcm_buffer.begin() + CHUNK_SIZE
96+
);
97+
session->pcm_buffer.erase(
98+
session->pcm_buffer.begin(),
99+
session->pcm_buffer.begin() + CHUNK_SIZE
100+
);
101+
102+
{
103+
std::lock_guard<std::mutex> ctx_lock(g_ctx_mtx);
104+
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
105+
wparams.print_progress = false;
106+
wparams.print_realtime = false;
107+
wparams.single_segment = true;
108+
109+
if (whisper_full(g_ctx, wparams, chunk.data(), chunk.size()) == 0) {
110+
const char* text = whisper_full_get_segment_text(g_ctx, 0);
111+
printf("[%s] %s\n", id.c_str(), text);
184112
}
113+
whisper_reset_timings(g_ctx);
185114
}
186115
}
187116
}
188117
}
189118
};
190119

191120
int main(int argc, char** argv) {
192-
server_params params;
193-
params.port = 9002;
194-
params.model = "ggml-large-v3-turbo.bin";
121+
if (argc < 3) {
122+
//fprintf(stderr, "Usage: %s <port> <model_path>\n", argv[0]);
123+
124+
WhisperServer server(9002, "ggml-large-v3-turbo.bin");
125+
server.run();
126+
return 0;
127+
}
195128

196-
WhisperServer server(params);
129+
WhisperServer server(atoi(argv[1]), argv[2]);
197130
server.run();
198131
return 0;
199-
}
132+
}
133+

0 commit comments

Comments
 (0)