Skip to content

Commit d98edff

Browse files
committed
upgrade sherpa-onnx
1 parent d4b4dd4 commit d98edff

File tree

12 files changed

+218
-1678
lines changed

12 files changed

+218
-1678
lines changed

CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
cmake_minimum_required(VERSION 3.28)
2-
set(VCPKG_TARGET_TRIPLET x64-windows)
2+
set(VCPKG_TARGET_TRIPLET x64-windows-static)
33
project(realtime-bilingual-asr)
44

55
set(CMAKE_CXX_STANDARD 20)
6+
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>")
67

78
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /SUBSYSTEM:WINDOWS /ENTRY:mainCRTStartup")
89
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
@@ -13,6 +14,7 @@ IF (NOT CMAKE_BUILD_TYPE STREQUAL Debug)
1314
ENDIF ()
1415

1516
include_directories(${CMAKE_SOURCE_DIR}/wtfdanmaku/include)
17+
include_directories(${CMAKE_SOURCE_DIR}/sherpa-onnx/include)
1618

1719
find_package(ixwebsocket REQUIRED)
1820
find_package(nlohmann_json REQUIRED)
@@ -33,8 +35,9 @@ target_link_libraries(${PROJECT_NAME}
3335
mfplat
3436
mf
3537
wmcodecdspuuid
36-
"${CMAKE_SOURCE_DIR}/sherpa-onnx-c-api.lib"
3738
bcrypt
39+
"${CMAKE_SOURCE_DIR}/sherpa-onnx/lib/sherpa-onnx-c-api.lib"
40+
"${CMAKE_SOURCE_DIR}/sherpa-onnx/lib/sherpa-onnx-cxx-api.lib"
3841
)
3942
IF (NOT CMAKE_BUILD_TYPE STREQUAL Debug)
4043
target_link_libraries(${PROJECT_NAME}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ JikkyoSubtitle: Real-time transcription and translation of system audio into Chi
2626

2727
**Prerequisites:**
2828

29+
0. **sherpa-onnx:** Obtain the pre-built libraries from the [Releases](https://github.com/k2-fsa/sherpa-onnx/releases/tag/v1.12.11) page.
2930
1. **Visual Studio with the "Desktop development with C++" workload:** Ensure it includes the MSVC compiler and CMake.
3031
2. **vcpkg:** Installed and integrated with your system (you've likely already done this if you have a vcpkg.json). Make sure `VCPKG_ROOT` environment variable is set and `vcpkg integrate install` has been run.
3132
3. **Git:** For fetching the project (if applicable).

SpeechRecognition.cpp

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,35 @@ void SpeechRecognition::init() {
3333
configFile >> config;
3434

3535
// 初始化加载 ASR 模型
36-
auto asr_model_path = config["asr"]["sense_voice"]["model_path"].get<std::string>();
37-
auto asr_model_lang = config["asr"]["sense_voice"]["language"].get<std::string>();
38-
SherpaOnnxOfflineSenseVoiceModelConfig sense_voice_config{
39-
asr_model_path.c_str(),
40-
asr_model_lang.c_str(),
41-
config["asr"]["sense_voice"]["num_threads"].get<int>()
42-
};
4336
// Offline model config
4437
SherpaOnnxOfflineModelConfig offline_model_config;
4538
memset(&offline_model_config, 0, sizeof(offline_model_config));
4639
offline_model_config.debug = 0;
4740
offline_model_config.num_threads = config["asr"]["num_threads"].get<int>();;
4841
auto asr_model_provider = config["asr"]["provider"].get<std::string>();
4942
offline_model_config.provider = asr_model_provider.c_str();
50-
auto asr_model_token_path = config["asr"]["token_path"].get<std::string>();
51-
offline_model_config.tokens = asr_model_token_path.c_str();
52-
offline_model_config.sense_voice = sense_voice_config;
43+
if (config["asr"].contains("sense_voice")) {
44+
auto asr_model_path = config["asr"]["sense_voice"]["model_path"].get<std::string>();
45+
auto asr_model_lang = config["asr"]["sense_voice"]["language"].get<std::string>();
46+
auto asr_model_token_path = config["asr"]["sense_voice"]["token_path"].get<std::string>();
47+
offline_model_config.tokens = asr_model_token_path.c_str();
48+
SherpaOnnxOfflineSenseVoiceModelConfig sense_voice_config{
49+
asr_model_path.c_str(),
50+
asr_model_lang.c_str(),
51+
config["asr"]["num_threads"].get<int>()
52+
};
53+
offline_model_config.sense_voice = sense_voice_config;
54+
} else if (config["asr"].contains("dolphin")) {
55+
auto asr_model_path = config["asr"]["dolphin"]["model_path"].get<std::string>();
56+
auto asr_model_token_path = config["asr"]["dolphin"]["token_path"].get<std::string>();
57+
offline_model_config.tokens = asr_model_token_path.c_str();
58+
SherpaOnnxOfflineDolphinModelConfig dolphin_config{
59+
asr_model_path.c_str(),
60+
};
61+
offline_model_config.dolphin = dolphin_config;
62+
} else {
63+
throw std::runtime_error("Failed to init config, unsupported ASR model type.");
64+
}
5365

5466
// Recognizer config
5567
SherpaOnnxOfflineRecognizerConfig recognizer_config;
@@ -58,46 +70,54 @@ void SpeechRecognition::init() {
5870
recognizer_config.model_config = offline_model_config;
5971

6072
recognizer = SherpaOnnxCreateOfflineRecognizer(&recognizer_config);
61-
6273
if (recognizer == nullptr) {
6374
throw std::runtime_error("Please check your recognizer config!\n");
6475
}
6576

6677
// VAD Config
6778
SherpaOnnxVadModelConfig vadConfig;
6879
memset(&vadConfig, 0, sizeof(vadConfig));
69-
auto vad_model_path = config["vad"]["silero_vad"]["model"].get<std::string>();
70-
vadConfig.silero_vad.model = vad_model_path.c_str();
71-
vadConfig.interrupt_threshold = config["vad"]["interrupt_threshold"].get<float>();;
72-
vadConfig.silero_vad.threshold = config["vad"]["silero_vad"]["threshold"].get<float>();;
73-
vadConfig.silero_vad.min_silence_duration = config["vad"]["silero_vad"]["min_silence_duration"].get<float>();
74-
vadConfig.silero_vad.min_speech_duration = config["vad"]["silero_vad"]["min_speech_duration"].get<float>();;
75-
vadConfig.silero_vad.max_speech_duration = config["vad"]["silero_vad"]["max_speech_duration"].get<float>();;
76-
vadConfig.silero_vad.window_size = 512;
7780
vadConfig.sample_rate = modelSampleRate;
78-
vadConfig.num_threads =
79-
vadConfig.num_threads = config["vad"]["num_threads"].get<int>();;
81+
vadConfig.num_threads = vadConfig.num_threads = config["vad"]["num_threads"].get<int>();;
8082
vadConfig.debug = 0;
8183
vadConfig.provider = "cpu";
84+
if (config["vad"].contains("silero_vad")) {
85+
auto vad_model_path = config["vad"]["silero_vad"]["model"].get<std::string>();
86+
vadConfig.silero_vad.model = vad_model_path.c_str();
87+
vadConfig.silero_vad.threshold = config["vad"]["silero_vad"]["threshold"].get<float>();;
88+
vadConfig.silero_vad.min_silence_duration = config["vad"]["silero_vad"]["min_silence_duration"].get<float>();
89+
vadConfig.silero_vad.min_speech_duration = config["vad"]["silero_vad"]["min_speech_duration"].get<float>();;
90+
vadConfig.silero_vad.max_speech_duration = config["vad"]["silero_vad"]["max_speech_duration"].get<float>();;
91+
vadConfig.silero_vad.window_size = 512;
92+
} else if (config["vad"].contains("ten_vad")) {
93+
auto vad_model_path = config["vad"]["ten_vad"]["model"].get<std::string>();
94+
vadConfig.ten_vad.model = vad_model_path.c_str();
95+
vadConfig.ten_vad.threshold = config["vad"]["ten_vad"]["threshold"].get<float>();;
96+
vadConfig.ten_vad.min_silence_duration = config["vad"]["ten_vad"]["min_silence_duration"].get<float>();
97+
vadConfig.ten_vad.min_speech_duration = config["vad"]["ten_vad"]["min_speech_duration"].get<float>();;
98+
vadConfig.ten_vad.max_speech_duration = config["vad"]["ten_vad"]["max_speech_duration"].get<float>();;
99+
vadConfig.ten_vad.window_size = 256;
100+
} else {
101+
throw std::runtime_error("Failed to init config, unsupported VAD model type.");
102+
}
82103

83104
vad = SherpaOnnxCreateVoiceActivityDetector(&vadConfig, 30);
84-
85105
if (vad == nullptr) {
86106
SherpaOnnxDestroyOfflineRecognizer(recognizer);
87107
throw std::runtime_error("Please check your vad config!\n");
88108
}
89109
std::cout << "VAD & ASR model loaded" << std::endl;
90110

91111
// LLM Params
92-
promptTemplate = config["llm"]["prompt_template"].get<std::string>();
93-
modelName = config["llm"]["model_name"].get<std::string>();
94-
modelAuth = config["llm"]["auth_key"].get<std::string>();
95-
llmServer = config["llm"]["api_base"].get<std::string>();
96-
isLlamaCpp = config["llm"]["is_llama_cpp"].get<bool>();
97-
isSakuraLLM = config["llm"]["is_sakura_llm"].get<bool>();
98-
modelMaxTokens = config["llm"]["max_tokens"].get<int>();
99-
modelTemperature = config["llm"]["temperature"].get<float>();
100-
modelTopP = config["llm"]["top_p"].get<float>();
112+
remoteLLMConfig.apiBaseUrl = config["llm"]["api_base"].get<std::string>();
113+
remoteLLMConfig.apiToken = config["llm"]["api_token"].get<std::string>();
114+
remoteLLMConfig.modelName = config["llm"]["model_name"].get<std::string>();
115+
remoteLLMConfig.isSakuraLLM = config["llm"].value("is_sakura_llm", false);
116+
remoteLLMConfig.promptTemplate = config["llm"]["prompt_template"].get<std::string>();
117+
remoteLLMConfig.engineType = config["llm"].value("is_llama_cpp", false) ? LLaMA_CPP : StandardOpenAI;
118+
remoteLLMConfig.samplingConfig.maxTokens = config["llm"].value("max_tokens", 512);
119+
remoteLLMConfig.samplingConfig.temperature = config["llm"].value("temperature", 0.1);
120+
remoteLLMConfig.samplingConfig.topP = config["llm"].value("topP", 0.3);
101121

102122
// Init ASR Handler
103123
asrCallback = [this](short *input, int32_t n_samples,
@@ -124,9 +144,9 @@ void SpeechRecognition::init() {
124144

125145
const SherpaOnnxOfflineStream *stream = SherpaOnnxCreateOfflineStream(recognizer);
126146

127-
SherpaOnnxAcceptWaveformOffline(stream, modelSampleRate, tail_paddings, 4800);
147+
SherpaOnnxAcceptWaveformOffline(stream, modelSampleRate, tailPaddings, 4800);
128148
SherpaOnnxAcceptWaveformOffline(stream, modelSampleRate, segment->samples, segment->n);
129-
SherpaOnnxAcceptWaveformOffline(stream, modelSampleRate, tail_paddings, 4800);
149+
SherpaOnnxAcceptWaveformOffline(stream, modelSampleRate, tailPaddings, 4800);
130150

131151
SherpaOnnxDecodeOfflineStream(recognizer, stream);
132152

@@ -146,7 +166,6 @@ void SpeechRecognition::init() {
146166
subtitles.emplace(text, result->lang);
147167
}
148168

149-
150169
SherpaOnnxDestroyOfflineRecognizerResult(result);
151170
SherpaOnnxDestroyOfflineStream(stream);
152171
SherpaOnnxDestroySpeechSegment(segment);
@@ -169,35 +188,35 @@ std::string SpeechRecognition::getTranslate(const std::string &text) {
169188
nonSpaceText.erase(std::ranges::remove_if(nonSpaceText,
170189
[](unsigned char c) { return std::isspace(c); }).begin(),
171190
nonSpaceText.end());
172-
ss << llmServer << "/v1/completions";
191+
ss << remoteLLMConfig.apiBaseUrl << "/v1/completions";
173192
const std::string url = ss.str();
174193
ix::HttpRequestArgsPtr args = httpClient.createRequest();
175194
ix::WebSocketHttpHeaders headers;
176-
headers["Authorization"] = "Bearer " + modelAuth;
195+
headers["Authorization"] = "Bearer " + remoteLLMConfig.apiToken;
177196
headers["content-type"] = "application/json";
178197
args->extraHeaders = headers;
179198
json payload = {
180-
{"model", modelName},
181-
{"max_tokens", modelMaxTokens},
182-
{"temperature", modelTemperature},
183-
{"top_p", modelTopP}
199+
{"model", remoteLLMConfig.modelName},
200+
{"max_tokens", remoteLLMConfig.samplingConfig.maxTokens},
201+
{"temperature", remoteLLMConfig.samplingConfig.temperature},
202+
{"top_p", remoteLLMConfig.samplingConfig.topP}
184203
};
185204

186205
// Use promptTemplate and format it
187-
std::string prompt = promptTemplate;
206+
std::string prompt = remoteLLMConfig.promptTemplate;
188207
size_t pos = prompt.find("%TEXT%");
189208
if (pos != std::string::npos) {
190209
prompt.replace(pos, 6, nonSpaceText);
191210
}
192-
if (isSakuraLLM) {
211+
if (remoteLLMConfig.isSakuraLLM) {
193212
payload["stop"] = {"<|im_end|>", "<|im_start|>"};
194213
}
195214
payload["prompt"] = prompt;
196215

197216
ix::HttpResponsePtr out = httpClient.post(url, payload.dump(), args);
198217
if (out->errorCode == ix::HttpErrorCode::Ok) {
199218
json llm_result = json::parse(out->body);
200-
if (isLlamaCpp) {
219+
if (remoteLLMConfig.engineType == LLaMA_CPP) {
201220
return llm_result["content"];
202221
}
203222
return llm_result["choices"][0]["text"];

SpeechRecognition.h

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,28 @@
1313
#include <thread>
1414

1515
#include "AudioCapture.h"
16-
#include "c-api.h"
16+
#include "sherpa-onnx/c-api/c-api.h"
17+
18+
enum LLMEngineType {
19+
StandardOpenAI = 0,
20+
LLaMA_CPP = 1
21+
};
22+
23+
struct SamplingConfig {
24+
int maxTokens;
25+
float temperature;
26+
float topP;
27+
};
28+
29+
struct RemoteLLMConfig {
30+
std::string apiToken;
31+
std::string apiBaseUrl;
32+
std::string modelName;
33+
bool isSakuraLLM;
34+
std::string promptTemplate;
35+
LLMEngineType engineType = StandardOpenAI;
36+
SamplingConfig samplingConfig;
37+
};
1738

1839
class SpeechSubtitle {
1940
public:
@@ -62,40 +83,34 @@ class SpeechRecognition {
6283
private:
6384
void capture();
6485

65-
// Loads configuration from the JSON file
6686
void loadConfig();
6787

6888
static bool initNetSystem();
6989

70-
ix::HttpClient httpClient;
71-
AudioCapture audioCapture;
72-
73-
std::string configFilePath; // Path to the configuration file
90+
// Global config / status
91+
std::string configFilePath;
92+
bool running = false;
93+
std::mutex subtitlesMutex;
7494

75-
// Configuration members loaded from the JSON file
76-
SherpaOnnxVadModelConfig vadConfig;
77-
std::string promptTemplate;
78-
std::string modelName;
79-
std::string modelAuth;
80-
int modelMaxTokens = 512;
81-
float modelTemperature = 0.1;
82-
float modelTopP = 0.3;
83-
std::string llmServer;
84-
bool isLlamaCpp = false;
85-
bool isSakuraLLM = false;
95+
// Audio capture pipeline
96+
AudioCapture audioCapture;
97+
std::thread captureThread;
8698

99+
// ASR pipeline
87100
int modelSampleRate = 16000;
88-
float tail_paddings[4800] = {0.}; // 0.3 seconds at 16 kHz sample rate
89-
const SherpaOnnxOfflineRecognizer *recognizer = nullptr;
101+
float tailPaddings[4800] = {0.}; // 0.3 seconds at 16 kHz sample rate
90102
SherpaOnnxOfflineRecognizerConfig recognizer_config;
91-
SherpaOnnxOfflineModelConfig offline_model_config;
92-
SherpaOnnxVoiceActivityDetector *vad = nullptr;
103+
const SherpaOnnxOfflineRecognizer *recognizer = nullptr;
104+
SherpaOnnxVadModelConfig vadConfig;
105+
const SherpaOnnxVoiceActivityDetector *vad = nullptr;
93106
std::function<void(short *, int32_t, int32_t)> asrCallback;
94107

95-
bool running = false;
108+
// LLM based translate pipeline
109+
ix::HttpClient httpClient;
110+
RemoteLLMConfig remoteLLMConfig;
111+
112+
// Render queue
96113
std::queue<SpeechSubtitle> subtitles;
97-
std::mutex subtitlesMutex;
98-
std::thread captureThread;
99114
};
100115

101116
#endif

0 commit comments

Comments
 (0)