Commit 8e51a97

Add C++ runtime for silero_vad with RKNN (#2078)
1 parent 0703bc1 commit 8e51a97

File tree: 12 files changed, +867 −16 lines


c-api-examples/vad-whisper-c-api.c

Lines changed: 5 additions & 6 deletions
@@ -100,12 +100,11 @@ int32_t main() {
   while (!is_eof) {
     if (i + window_size < wave->num_samples) {
-      SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
-                                                    window_size);
-    }
-    else {
-      SherpaOnnxVoiceActivityDetectorFlush(vad);
-      is_eof = 1;
+      SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
+                                                    window_size);
+    } else {
+      SherpaOnnxVoiceActivityDetectorFlush(vad);
+      is_eof = 1;
     }
     while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) {
       const SherpaOnnxSpeechSegment *segment =
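For reference, the loop around this hunk feeds the wave to the VAD in fixed-size windows, flushes once the input is exhausted, and drains finished speech segments after each step. A rough Python sketch of that control flow — the `vad` object and its method names are illustrative stand-ins for the C API calls shown in the diff, not an actual binding:

    i = 0
    is_eof = False
    while not is_eof:
        if i + window_size < num_samples:
            # corresponds to SherpaOnnxVoiceActivityDetectorAcceptWaveform
            vad.accept_waveform(samples[i : i + window_size])
        else:
            # corresponds to SherpaOnnxVoiceActivityDetectorFlush:
            # no more input, so emit any pending segment
            vad.flush()
            is_eof = True
        # corresponds to the SherpaOnnxVoiceActivityDetectorEmpty drain loop
        while not vad.empty():
            segment = vad.front()  # a finished speech segment
            # ... process the segment ...
            vad.pop()
        i += window_size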

scripts/gtcrn/show.py

Lines changed: 75 additions & 0 deletions
#!/usr/bin/env python3
# Copyright    2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnxruntime
import onnx

"""
[key: "model_type"
value: "gtcrn"
, key: "comment"
value: "gtcrn_simple"
, key: "version"
value: "1"
, key: "sample_rate"
value: "16000"
, key: "model_url"
value: "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx"
, key: "maintainer"
value: "k2-fsa"
, key: "comment2"
value: "Please see also https://github.com/Xiaobin-Rong/gtcrn"
, key: "conv_cache_shape"
value: "2,1,16,16,33"
, key: "tra_cache_shape"
value: "2,3,1,1,16"
, key: "inter_cache_shape"
value: "2,1,33,16"
, key: "n_fft"
value: "512"
, key: "hop_length"
value: "256"
, key: "window_length"
value: "512"
, key: "window_type"
value: "hann_sqrt"
]
"""

"""
NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""


def show(filename):
    model = onnx.load(filename)
    print(model.metadata_props)

    session_opts = onnxruntime.SessionOptions()
    session_opts.log_severity_level = 3
    sess = onnxruntime.InferenceSession(
        filename, session_opts, providers=["CPUExecutionProvider"]
    )
    for i in sess.get_inputs():
        print(i)

    print("-----")

    for i in sess.get_outputs():
        print(i)


def main():
    show("./gtcrn_simple.onnx")


if __name__ == "__main__":
    main()
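The metadata printed above pins down the STFT front end (n_fft = 512, hop_length = 256, window_type = "hann_sqrt") and the input shape [1, 257, 1, 2], where 257 = 512 / 2 + 1 frequency bins per frame. A minimal sketch of how one such frame could be prepared — the framing details here are an assumption for illustration, not part of this script:

    import numpy as np

    n_fft = 512

    # sqrt-Hann analysis window, per window_type = "hann_sqrt"
    window = np.sqrt(np.hanning(n_fft))

    # one 512-sample frame -> 257 complex bins, packed as (real, imag)
    frame = np.random.randn(n_fft).astype(np.float32)
    spec = np.fft.rfft(frame * window)  # shape (257,), complex
    mix = np.stack([spec.real, spec.imag], axis=-1)  # (257, 2)
    mix = mix[None, :, None, :].astype(np.float32)  # (1, 257, 1, 2)
    print(mix.shape)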

scripts/silero_vad/v4/export-onnx.py

Lines changed: 80 additions & 1 deletion
@@ -5,15 +5,94 @@
 import torch
 from onnxsim import simplify
 
+import torch
+from torch import Tensor
+
+
+def simple_pad(x: Tensor, pad: int) -> Tensor:
+    # _0 = torch.slice(torch.slice(torch.slice(x), 1), 2, 1, torch.add(1, pad))
+    _0 = x[:, :, 1 : 1 + pad]
+
+    left_pad = torch.flip(_0, [-1])
+    # _1 = torch.slice(torch.slice(torch.slice(x), 1), 2, torch.sub(-1, pad), -1)
+    _1 = x[:, :, (-1 - pad) : -1]
+
+    right_pad = torch.flip(_1, [-1])
+    _2 = torch.cat([left_pad, x, right_pad], 2)
+    return _2
+
+
+class MyModule(torch.nn.Module):
+    def __init__(self, m):
+        super().__init__()
+        self.m = m
+
+    def adaptive_normalization_forward(self, spect):
+        m = self.m._model.adaptive_normalization
+        _0 = simple_pad
+
+        # Note(fangjun): rknn uses fp16 by default, whose max value is 65504,
+        # so we need to re-write the computation for spect0
+        # spect0 = torch.log1p(torch.mul(spect, 1048576))
+        spect0 = torch.log1p(spect) + 13.86294
+
+        _1 = torch.eq(len(spect0.shape), 2)
+        if _1:
+            _2 = torch.unsqueeze(spect0, 0)
+            spect1 = _2
+        else:
+            spect1 = spect0
+        mean = torch.mean(spect1, [1], True)
+        to_pad = m.to_pad
+        mean0 = _0(
+            mean,
+            to_pad,
+        )
+        filter_ = m.filter_
+        mean1 = torch.conv1d(mean0, filter_)
+        mean_mean = torch.mean(mean1, [-1], True)
+        spect2 = torch.add(spect1, torch.neg(mean_mean))
+        return spect2
+
+    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
+        m = self.m._model
+
+        feature_extractor = m.feature_extractor
+        x0 = (feature_extractor).forward(
+            x,
+        )
+        norm = self.adaptive_normalization_forward(x0)
+        x1 = torch.cat([x0, norm], 1)
+        first_layer = m.first_layer
+        x2 = (first_layer).forward(
+            x1,
+        )
+        encoder = m.encoder
+        x3 = (encoder).forward(
+            x2,
+        )
+        decoder = m.decoder
+        x4, h0, c0, = (decoder).forward(
+            x3,
+            h,
+            c,
+        )
+        _0 = torch.mean(torch.squeeze(x4, 1), [1])
+        out = torch.unsqueeze(_0, 1)
+        return (out, h0, c0)
+
 
 @torch.no_grad()
 def main():
     m = torch.jit.load("./silero_vad.jit")
+    m = MyModule(m)
     x = torch.rand((1, 512), dtype=torch.float32)
     h = torch.rand((2, 1, 64), dtype=torch.float32)
     c = torch.rand((2, 1, 64), dtype=torch.float32)
+    m = torch.jit.script(m)
     torch.onnx.export(
-        m._model,
+        m,
         (x, h, c),
         "m.onnx",
         input_names=["x", "h", "c"],
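The constant 13.86294 in adaptive_normalization_forward is log(1048576) = 20 * log(2) ≈ 13.8629, so for spect values well above 1, log1p(spect) + 13.86294 tracks the original log1p(spect * 1048576) while keeping every intermediate value inside fp16 range (max 65504). A quick numeric check of that reasoning, relying only on the math above:

    import math

    import numpy as np

    # the original scale factor itself overflows fp16
    print(np.float16(1048576))  # inf, since fp16 tops out at 65504

    print(20 * math.log(2))  # 13.8629..., the constant used in the rewrite

    # for spect >> 1 the two formulations nearly agree
    spect = 100.0
    print(math.log1p(spect * 1048576))   # ~18.468
    print(math.log1p(spect) + 13.86294)  # ~18.478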

scripts/silero_vad/v4/show.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright    2024 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright    2025 Xiaomi Corp. (authors: Fangjun Kuang)
 
 import onnxruntime
 import onnx
Lines changed: 141 additions & 0 deletions
#!/usr/bin/env python3
# Copyright    2025 Xiaomi Corp. (authors: Fangjun Kuang)

# Please run this file on your rk3588 board

try:
    from rknnlite.api import RKNNLite
except:
    print("Please run this file on your board (linux + aarch64 + npu)")
    print("You need to install rknn_toolkit_lite2")
    print(
        " from https://github.com/airockchip/rknn-toolkit2/tree/master/rknn-toolkit-lite2/packages"
    )
    print(
        "https://github.com/airockchip/rknn-toolkit2/blob/v2.1.0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.1.0-cp310-cp310-linux_aarch64.whl"
    )
    print("is known to work")
    raise

import time
from pathlib import Path
from typing import Tuple

import numpy as np
import soundfile as sf


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel

    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def init_model(filename, target_platform="rk3588"):
    if not Path(filename).is_file():
        exit(f"{filename} does not exist")

    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(path=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        exit(f"Failed to init rknn runtime for {filename}")
    return rknn_lite


class RKNNModel:
    def __init__(self, model: str, target_platform="rk3588"):
        self.model = init_model(model)

    def release(self):
        self.model.release()

    def __call__(self, x: np.ndarray, h: np.ndarray, c: np.ndarray):
        """
        Args:
          x: (1, 512), np.float32
          h: (2, 1, 64), np.float32
          c: (2, 1, 64), np.float32
        Returns:
          prob:
          next_h:
          next_c
        """
        out, next_h, next_c = self.model.inference(inputs=[x, h, c])
        return out.item(), next_h, next_c


def main():
    model = RKNNModel(model="./m.rknn")
    for i in range(1):
        test(model)


def test(model):
    print("started")
    start = time.time()
    samples, sample_rate = load_audio("./lei-jun-test.wav")
    assert sample_rate == 16000, sample_rate

    window_size = 512

    h = np.zeros((2, 1, 64), dtype=np.float32)
    c = np.zeros((2, 1, 64), dtype=np.float32)

    threshold = 0.5
    num_windows = samples.shape[0] // window_size
    out = []
    for i in range(num_windows):
        print(i, num_windows)
        this_samples = samples[i * window_size : (i + 1) * window_size]
        prob, h, c = model(this_samples[None], h, c)
        out.append(prob > threshold)

    min_speech_duration = 0.25 * sample_rate / window_size
    min_silence_duration = 0.25 * sample_rate / window_size

    result = []
    last = -1
    for k, f in enumerate(out):
        if f >= threshold:
            if last == -1:
                last = k
        elif last != -1:
            if k - last > min_speech_duration:
                result.append((last, k))
            last = -1

    if last != -1 and k - last > min_speech_duration:
        result.append((last, k))

    if not result:
        print("Empty for ./lei-jun-test.wav")
        return

    print(result)

    final = [result[0]]
    for r in result[1:]:
        f = final[-1]
        if r[0] - f[1] < min_silence_duration:
            final[-1] = (f[0], r[1])
        else:
            final.append(r)

    for f in final:
        start = f[0] * window_size / sample_rate
        end = f[1] * window_size / sample_rate
        print("{:.3f} -- {:.3f}".format(start, end))


if __name__ == "__main__":
    main()
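The post-processing in test() above works entirely in window indices: at 16 kHz with 512-sample windows each window covers 32 ms, so the 0.25 s thresholds become 0.25 * 16000 / 512 = 7.8125 windows. A toy run of the same two passes (collect speech runs, then merge runs split by short silences) — the input here is made up for illustration:

    sample_rate = 16000
    window_size = 512
    min_gap = 0.25 * sample_rate / window_size  # 7.8125 windows ~ 0.25 s

    # 12 speech windows, a 3-window dip, then 12 more speech windows
    out = [False] * 10 + [True] * 12 + [False] * 3 + [True] * 12 + [False] * 5

    # pass 1: keep speech runs longer than min_gap
    result, last = [], -1
    for k, f in enumerate(out):
        if f:
            if last == -1:
                last = k
        elif last != -1:
            if k - last > min_gap:
                result.append((last, k))
            last = -1
    print(result)  # [(10, 22), (25, 37)]

    # pass 2: merge runs separated by less than min_gap of silence
    final = [result[0]]
    for r in result[1:]:
        if r[0] - final[-1][1] < min_gap:
            final[-1] = (final[-1][0], r[1])
        else:
            final.append(r)
    print(final)  # [(10, 37)] -> 0.320 s to 1.184 s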

scripts/silero_vad/v4/test-onnx.py

Lines changed: 3 additions & 0 deletions
@@ -97,10 +97,13 @@ def main():
     h, c = model.get_init_states()
     window_size = 512
     num_windows = samples.shape[0] // window_size
+
     for i in range(num_windows):
         start = i * window_size
         end = start + window_size
+
         p, h, c = model(samples[start:end], h, c)
+
         probs.append(p[0].item())
 
     threshold = 0.5

sherpa-onnx/csrc/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
     ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
     ./rknn/online-zipformer-ctc-model-rknn.cc
     ./rknn/online-zipformer-transducer-model-rknn.cc
+    ./rknn/silero-vad-model-rknn.cc
     ./rknn/utils.cc
   )
 
@@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
     microphone.cc
   )
 
+
   add_executable(sherpa-onnx-microphone-offline
     sherpa-onnx-microphone-offline.cc
     microphone.cc
@@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
   )
 
   set(exes
-    sherpa-onnx-microphone
     sherpa-onnx-keyword-spotter-microphone
+    sherpa-onnx-microphone
     sherpa-onnx-microphone-offline
-    sherpa-onnx-microphone-offline-speaker-identification
     sherpa-onnx-microphone-offline-audio-tagging
+    sherpa-onnx-microphone-offline-speaker-identification
     sherpa-onnx-vad-microphone
     sherpa-onnx-vad-microphone-offline-asr
     sherpa-onnx-vad-with-offline-asr
