
Commit 22f6a81

Author: lexasub

dirty experiments with websocket-stream

1 parent f8a3509 commit 22f6a81

File tree: 5 files changed (+297, -0 lines)


examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ else()
     add_subdirectory(quantize)
     if (WHISPER_SDL2)
         add_subdirectory(stream)
+        add_subdirectory(websocket-stream)
         add_subdirectory(command)
         add_subdirectory(talk-llama)
         add_subdirectory(lsp)
examples/websocket-stream/CMakeLists.txt

Lines changed: 10 additions & 0 deletions (new file)

if (WHISPER_SDL2)
    set(TARGET whisper-stream-websocket)
    add_executable(${TARGET} stream.cpp)
    find_package(ixwebsocket)
    include(DefaultTargetOptions)

    target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ixwebsocket z ${CMAKE_THREAD_LIBS_INIT})

    install(TARGETS ${TARGET} RUNTIME)
endif ()
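
Note that `find_package(ixwebsocket)` assumes the IXWebSocket library (and zlib, linked here as `z`) is already installed where CMake can locate it; the commit does not vendor either dependency. With those in place, the target is built by the same `cmake -B build -DWHISPER_SDL2=ON` / `cmake --build build --config Release` invocation shown in the README below.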

examples/websocket-stream/README.md

Lines changed: 51 additions & 0 deletions (new file)

# whisper.cpp/examples/websocket-stream

This is a naive example of performing real-time inference on audio from your microphone.
The `whisper-stream` tool samples the audio every half a second and runs the transcription continuously.
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).

```bash
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```

https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4

## Sliding window mode with VAD

Setting the `--step` argument to `0` enables the sliding window mode:

```bash
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 6 --step 0 --length 30000 -vth 0.6
```

In this mode, the tool transcribes only after some speech activity is detected. A very
basic VAD detector is used, but in theory a more sophisticated approach could be added. The
`-vth` argument determines the VAD threshold: higher values make it detect silence more often.
It's best to tune it for the specific use case, but a value around `0.6` should work well in general.
When silence is detected, the tool transcribes the last `--length` milliseconds of audio and outputs
a transcription block that is suitable for parsing.

## Building

The `whisper-stream` tool depends on the SDL2 library to capture audio from the microphone. You can build it like this:

```bash
# Install SDL2
# On Debian-based Linux distributions:
sudo apt-get install libsdl2-dev

# On Fedora Linux:
sudo dnf install SDL2 SDL2-devel

# Install SDL2 on macOS
brew install sdl2

cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config Release

./build/bin/whisper-stream
```

## Web version

This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)

examples/websocket-stream/stream.cpp

Lines changed: 159 additions & 0 deletions (new file)

#include "common.h"
#include "common-whisper.h"
#include "whisper.h"
#include "ixwebsocket/IXWebSocketServer.h"
#include "ixwebsocket/IXNetSystem.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

struct ClientSession {
    std::vector<float> pcm_buffer;
    std::mutex mtx;
    std::condition_variable cv;
    std::atomic<bool> active{false};
    std::atomic<bool> terminate{false};
};

struct server_params {
    int32_t port      = 9002;
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
    std::string model = "models/ggml-base.en.bin";
    bool use_gpu      = true;
};

class WhisperServer {
private:
    server_params params;
    ix::WebSocketServer server;
    std::unordered_map<std::string, std::unique_ptr<ClientSession>> clients;
    std::mutex clients_mtx;
    whisper_context * ctx = nullptr;

public:
    WhisperServer(const server_params & params) : params(params), server(params.port, "0.0.0.0") {
        ix::initNetSystem();

        whisper_context_params cparams = whisper_context_default_params();
        cparams.use_gpu = params.use_gpu;
        ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);

        server.setTLSOptions({});
        server.setOnClientMessageCallback([this](std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket & webSocket, const ix::WebSocketMessagePtr & msg) {
            this->handleMessage(connectionState, webSocket, msg);
        });
    }

    ~WhisperServer() {
        server.stop();
        ix::uninitNetSystem();
        if (ctx) whisper_free(ctx);
    }

    void run() {
        server.listenAndStart();
        fprintf(stderr, "Server started on port %d\n", params.port);

        while (true) {
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
    }

private:
    void handleMessage(std::shared_ptr<ix::ConnectionState> connectionState, ix::WebSocket & webSocket, const ix::WebSocketMessagePtr & msg) {
        const auto client_id = connectionState->getId();

        if (msg->type == ix::WebSocketMessageType::Open) {
            fprintf(stderr, "New client connected: %s\n", client_id.c_str());
            std::lock_guard<std::mutex> lock(clients_mtx);
            clients[client_id] = std::make_unique<ClientSession>();
            clients[client_id]->active = true;
            // one detached worker per client runs the transcription loop
            std::thread(&WhisperServer::processClientAudio, this, client_id).detach();
        }
        else if (msg->type == ix::WebSocketMessageType::Close) {
            fprintf(stderr, "Client disconnected: %s\n", client_id.c_str());
            std::lock_guard<std::mutex> lock(clients_mtx);
            if (clients.count(client_id)) {
                clients[client_id]->terminate = true;
                clients[client_id]->cv.notify_one();
                clients.erase(client_id);
            }
        }
        else if (msg->type == ix::WebSocketMessageType::Message) {
            std::lock_guard<std::mutex> lock(clients_mtx);
            if (auto it = clients.find(client_id); it != clients.end()) {
                auto & session = *it->second;
                std::lock_guard<std::mutex> session_lock(session.mtx);

                // PCM16 -> FLOAT32
                const int16_t * pcm16 = reinterpret_cast<const int16_t *>(msg->str.data());
                const size_t num_samples = msg->str.size() / sizeof(int16_t);

                session.pcm_buffer.reserve(session.pcm_buffer.size() + num_samples);
                for (size_t i = 0; i < num_samples; ++i) {
                    session.pcm_buffer.push_back(pcm16[i] / 32768.0f);
                }

                session.cv.notify_one();
            }
        }
    }

    void processClientAudio(std::string client_id) {
        constexpr int step_ms = 3000;
        constexpr int n_samples_step = (1e-3 * step_ms) * WHISPER_SAMPLE_RATE;

        while (true) {
            std::vector<float> audio_chunk;
            {
                std::unique_lock<std::mutex> lock(clients_mtx);
                if (!clients.count(client_id)) break;
                auto & session = *clients[client_id];

                std::unique_lock<std::mutex> session_lock(session.mtx);
                session.cv.wait_for(session_lock, std::chrono::milliseconds(100), [&session] {
                    return session.pcm_buffer.size() >= n_samples_step || session.terminate;
                });

                if (session.terminate) break;

                // take one step-sized chunk off the front of the buffer
                if (session.pcm_buffer.size() >= n_samples_step) {
                    audio_chunk.assign(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step);
                    session.pcm_buffer.erase(session.pcm_buffer.begin(), session.pcm_buffer.begin() + n_samples_step);
                }
            }

            if (!audio_chunk.empty()) {
                whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
                wparams.n_threads      = params.n_threads;
                wparams.print_progress = false;

                if (whisper_full(ctx, wparams, audio_chunk.data(), audio_chunk.size()) == 0) {
                    const int n_segments = whisper_full_n_segments(ctx);
                    for (int i = 0; i < n_segments; ++i) {
                        const char * text = whisper_full_get_segment_text(ctx, i);
                        fprintf(stdout, "[Client %s] %s\n", client_id.c_str(), text);
                    }
                }
            }
        }
    }
};

int main(int argc, char ** argv) {
    server_params params;
    params.port  = 9002;
    params.model = "../models/for-tests-ggml-base.bin";

    WhisperServer server(params);
    server.run();
    return 0;
}
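
The server accepts binary WebSocket frames of raw little-endian 16-bit PCM, mono, at `WHISPER_SAMPLE_RATE` (16 kHz), buffers them per client, and runs inference on 3-second chunks; transcriptions are printed to the server's stdout and nothing is sent back to the client. A minimal sketch of a microphone client, assuming the `websockets` and `sounddevice` Python packages are installed and the server is running locally on port 9002:

```python
# Hypothetical client sketch; assumes `pip install websockets sounddevice`
# and the server above listening on ws://localhost:9002.
import sounddevice as sd
from websockets.sync.client import connect

SAMPLE_RATE = 16000              # must match WHISPER_SAMPLE_RATE on the server
CHUNK_FRAMES = SAMPLE_RATE // 2  # ~500 ms of audio per WebSocket frame

with connect("ws://localhost:9002") as ws:
    # raw int16 mono stream -- the exact format stream.cpp converts to float32
    with sd.RawInputStream(samplerate=SAMPLE_RATE, channels=1, dtype="int16") as mic:
        while True:
            data, overflowed = mic.read(CHUNK_FRAMES)
            ws.send(bytes(data))  # binary frame, appended to the client's pcm_buffer
```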
examples/websocket-stream/test_whisper_server.py

Lines changed: 76 additions & 0 deletions (new file)

# test_whisper_server.py
import asyncio
import subprocess
import threading
import time
from queue import Queue

import numpy as np
import websockets

# Test configuration
SERVER_URL = "ws://localhost:9002"
NUM_CLIENTS = 3
TEST_DURATION = 15  # seconds
SAMPLE_RATE = 16000

# Queue for collecting results
results = Queue()

# Test audio generator (a sine wave with a different frequency per client)
def generate_audio(client_id, duration_sec):
    t = np.linspace(0, duration_sec, int(SAMPLE_RATE * duration_sec), False)
    freq = 440 + client_id * 100  # unique frequency per client
    audio = np.sin(2 * np.pi * freq * t) * 0.5
    return (audio * 32767).astype(np.int16).tobytes()

async def client_worker(client_id):
    try:
        async with websockets.connect(SERVER_URL) as ws:
            print(f"Client {client_id} connected")

            start_time = time.time()
            while time.time() - start_time < TEST_DURATION:
                audio_data = generate_audio(client_id, 0.5)  # 500 ms chunks
                await ws.send(audio_data)
                await asyncio.sleep(0.1)

            await ws.close()
        results.put((client_id, "OK"))
    except Exception as e:
        results.put((client_id, f"Error: {str(e)}"))

def run_server():
    # Replace with the path to your binary
    subprocess.run(["./bin/whisper-stream-websocket"], check=True)

def test_multi_clients():
    # Start the server in a separate thread
    #server_thread = threading.Thread(target=run_server, daemon=True)
    #server_thread.start()
    #time.sleep(2)  # give the server time to start

    # Start the clients
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    tasks = []
    for i in range(NUM_CLIENTS):
        tasks.append(client_worker(i))

    loop.run_until_complete(asyncio.gather(*tasks))
    loop.close()

    # Check the results
    all_ok = True
    while not results.empty():
        client_id, status = results.get()
        print(f"Client {client_id}: {status}")
        if status != "OK":
            all_ok = False

    assert all_ok, "Some clients failed"
    print("All clients finished successfully")

if __name__ == "__main__":
    test_multi_clients()
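
To exercise the server, start the binary first (the `run_server` helper that would launch it in a background thread is left commented out), then run `python test_whisper_server.py`. The sine-wave input only checks connection handling and chunking; the transcriptions the server prints for such audio are not expected to be meaningful.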
