|
#include "common.h"
#include "common-whisper.h"
#include "whisper.h"
#include "ixwebsocket/IXWebSocketServer.h"
#include "ixwebsocket/IXNetSystem.h"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
| 17 | + |
// Per-connection state shared between the WebSocket message handler and the
// worker thread that transcribes this client's audio.
struct ClientSession {
    // Float samples converted from incoming PCM16 messages; the handler
    // appends to it and the worker removes fixed-size chunks from the front.
    std::vector<float> pcm_buffer;

    // Guards pcm_buffer and is the mutex paired with `cv`.
    std::mutex mtx;

    // Signalled when new samples arrive or when the session is terminated.
    std::condition_variable cv;

    // Set to true when the connection is opened.
    std::atomic_bool active{false};

    // Set to true when the connection is closed so the worker can exit.
    std::atomic_bool terminate{false};
};
| 25 | + |
// Runtime configuration of the transcription server.
struct server_params {
    // TCP port the WebSocket server listens on.
    int32_t port = 9002;

    // Number of threads used per inference call: at most 4, at least 1.
    // std::thread::hardware_concurrency() is allowed to return 0 when the
    // value is not computable, so the result is clamped to stay positive.
    int32_t n_threads = std::max<int32_t>(
        1, std::min<int32_t>(4, static_cast<int32_t>(std::thread::hardware_concurrency())));

    // Path of the ggml model file to load.
    std::string model = "models/ggml-base.en.bin";

    // Whether the whisper context should be created with GPU support.
    bool use_gpu = true;
};
| 32 | + |
| 33 | +class WhisperServer { |
| 34 | +private: |
| 35 | + server_params params; |
| 36 | + ix::WebSocketServer server; |
| 37 | + std::unordered_map<std::string, std::unique_ptr<ClientSession>> clients; |
| 38 | + std::mutex clients_mtx; |
| 39 | + whisper_context* ctx = nullptr; |
| 40 | + |
| 41 | +public: |
| 42 | + WhisperServer(const server_params& params) : params(params), server(params.port, "0.0.0.0") { |
| 43 | + ix::initNetSystem(); |
| 44 | + |
| 45 | + whisper_context_params cparams = whisper_context_default_params(); |
| 46 | + cparams.use_gpu = params.use_gpu; |
| 47 | + ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); |
| 48 | + |
| 49 | + server.setTLSOptions({}); |
| 50 | + server.setOnClientMessageCallback([this](std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket& webSocket, const ix::WebSocketMessagePtr& msg) { |
| 51 | + this->handleMessage(connectionState, webSocket, msg); |
| 52 | + }); |
| 53 | + } |
| 54 | + |
| 55 | + ~WhisperServer() { |
| 56 | + server.stop(); |
| 57 | + ix::uninitNetSystem(); |
| 58 | + if(ctx) whisper_free(ctx); |
| 59 | + } |
| 60 | + |
| 61 | + void run() { |
| 62 | + server.listenAndStart(); |
| 63 | + fprintf(stderr, "Server started on port %d\n", params.port); |
| 64 | + |
| 65 | + while(true) { |
| 66 | + std::this_thread::sleep_for(std::chrono::seconds(1)); |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | +private: |
| 71 | + void handleMessage(std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket& webSocket, const ix::WebSocketMessagePtr& msg) { |
| 72 | + const auto client_id = connectionState->getId(); |
| 73 | + |
| 74 | + if(msg->type == ix::WebSocketMessageType::Open) { |
| 75 | + fprintf(stderr, "New client connected: %s\n", client_id); |
| 76 | + std::lock_guard<std::mutex> lock(clients_mtx); |
| 77 | + clients[client_id] = std::make_unique<ClientSession>(); |
| 78 | + clients[client_id]->active = true; |
| 79 | + std::thread(&WhisperServer::processClientAudio, this, client_id).detach(); |
| 80 | + } |
| 81 | + else if(msg->type == ix::WebSocketMessageType::Close) { |
| 82 | + fprintf(stderr, "Client disconnected: %s\n", client_id); |
| 83 | + std::lock_guard<std::mutex> lock(clients_mtx); |
| 84 | + if(clients.count(client_id)) { |
| 85 | + clients[client_id]->terminate = true; |
| 86 | + clients[client_id]->cv.notify_one(); |
| 87 | + clients.erase(client_id); |
| 88 | + } |
| 89 | + } |
| 90 | + else if(msg->type == ix::WebSocketMessageType::Message) { |
| 91 | + std::lock_guard<std::mutex> lock(clients_mtx); |
| 92 | + if(auto it = clients.find(client_id); it != clients.end()) { |
| 93 | + auto& session = *it->second; |
| 94 | + std::lock_guard<std::mutex> session_lock(session.mtx); |
| 95 | + |
| 96 | + //PCM16 -> FLOAT32 |
| 97 | + const int16_t* pcm16 = reinterpret_cast<const int16_t*>(msg->str.data()); |
| 98 | + const size_t num_samples = msg->str.size() / sizeof(int16_t); |
| 99 | + |
| 100 | + session.pcm_buffer.reserve(session.pcm_buffer.size() + num_samples); |
| 101 | + for(size_t i = 0; i < num_samples; ++i) { |
| 102 | + session.pcm_buffer.push_back(pcm16[i] / 32768.0f); |
| 103 | + } |
| 104 | + |
| 105 | + session.cv.notify_one(); |
| 106 | + } |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + void processClientAudio(std::string client_id) { |
| 111 | + constexpr int step_ms = 3000; |
| 112 | + constexpr int n_samples_step = (1e-3 * step_ms) * WHISPER_SAMPLE_RATE; |
| 113 | + |
| 114 | + while(true) { |
| 115 | + std::vector<float> audio_chunk; |
| 116 | + { |
| 117 | + std::unique_lock<std::mutex> lock(clients_mtx); |
| 118 | + if(!clients.count(client_id)) break; |
| 119 | + auto& session = *clients[client_id]; |
| 120 | + |
| 121 | + std::unique_lock<std::mutex> session_lock(session.mtx); |
| 122 | + session.cv.wait_for(session_lock, std::chrono::milliseconds(100), [&session] { |
| 123 | + return session.pcm_buffer.size() >= n_samples_step || session.terminate; |
| 124 | + }); |
| 125 | + |
| 126 | + if(session.terminate) break; |
| 127 | + |
| 128 | + if(session.pcm_buffer.size() >= n_samples_step) { |
| 129 | + audio_chunk.assign(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step); |
| 130 | + session.pcm_buffer.erase(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step); |
| 131 | + } |
| 132 | + } |
| 133 | + |
| 134 | + if(!audio_chunk.empty()) { |
| 135 | + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); |
| 136 | + wparams.n_threads = params.n_threads; |
| 137 | + wparams.print_progress = false; |
| 138 | + |
| 139 | + if(whisper_full(ctx, wparams, audio_chunk.data(), audio_chunk.size()) == 0) { |
| 140 | + const int n_segments = whisper_full_n_segments(ctx); |
| 141 | + for(int i = 0; i < n_segments; ++i) { |
| 142 | + const char* text = whisper_full_get_segment_text(ctx, i); |
| 143 | + fprintf(stdout, "[Client %s] %s\n", client_id, text); |
| 144 | + } |
| 145 | + } |
| 146 | + } |
| 147 | + } |
| 148 | + } |
| 149 | +}; |
| 150 | + |
| 151 | +int main(int argc, char** argv) { |
| 152 | + server_params params; |
| 153 | + params.port = 9002; |
| 154 | + params.model = "../models/for-tests-ggml-base.bin"; |
| 155 | + |
| 156 | + WhisperServer server(params); |
| 157 | + server.run(); |
| 158 | + return 0; |
| 159 | +} |
0 commit comments