diff --git a/sherpa-onnx/csrc/silero-vad-model-config.cc b/sherpa-onnx/csrc/silero-vad-model-config.cc
index 9237963ba4..7253bd4bcc 100644
--- a/sherpa-onnx/csrc/silero-vad-model-config.cc
+++ b/sherpa-onnx/csrc/silero-vad-model-config.cc
@@ -31,8 +31,7 @@ void SileroVadModelConfig::Register(ParseOptions *po) {
   po->Register(
       "silero-vad-max-speech-duration", &max_speech_duration,
       "In seconds. If a speech segment is longer than this value, then we "
-      "increase the threshold to 0.9. After finishing detecting the segment, "
-      "the threshold value is reset to its original value.");
+      "split it into multiple segments.");
 
   po->Register(
       "silero-vad-window-size", &window_size,
@@ -102,12 +101,12 @@ bool SileroVadModelConfig::Validate() const {
 
 std::string SileroVadModelConfig::ToString() const {
   std::ostringstream os;
 
   os << "SileroVadModelConfig(";
   os << "model=\"" << model << "\", ";
   os << "threshold=" << threshold << ", ";
   os << "min_silence_duration=" << min_silence_duration << ", ";
   os << "min_speech_duration=" << min_speech_duration << ", ";
-  os << "max_speech_duration=" << max_speech_duration << ", ";
+  os << "max_speech_duration=" << max_speech_duration << ", ";
   os << "window_size=" << window_size << ")";
 
   return os.str();
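Reviewer note: all of the `silero-vad-*-duration` options are given in seconds and are converted to sample counts once the sample rate is known (see silero-vad-model.cc further down). A self-contained sketch of that conversion; the 16 kHz rate and the 0.5 s / 0.25 s values are illustrative assumptions, only the 20 s default comes from this patch:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed values for illustration; only max_speech_duration = 20 s is a
  // default introduced by this patch.
  const int32_t sample_rate = 16000;
  const float min_silence_duration = 0.5f;  // seconds
  const float min_speech_duration = 0.25f;  // seconds
  const float max_speech_duration = 20.0f;  // seconds

  // The same seconds-to-samples conversion SileroVadModel::Impl performs.
  const auto to_samples = [&](float seconds) {
    return static_cast<int32_t>(sample_rate * seconds);
  };

  std::printf("min_silence_samples = %d\n", to_samples(min_silence_duration));
  std::printf("min_speech_samples  = %d\n", to_samples(min_speech_duration));
  std::printf("max_speech_samples  = %d\n", to_samples(max_speech_duration));
  // Prints 8000, 4000, and 320000 respectively.
  return 0;
}
```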
diff --git a/sherpa-onnx/csrc/silero-vad-model-config.h b/sherpa-onnx/csrc/silero-vad-model-config.h
index 5ae06f5ff7..141ae53c02 100644
--- a/sherpa-onnx/csrc/silero-vad-model-config.h
+++ b/sherpa-onnx/csrc/silero-vad-model-config.h
@@ -27,10 +27,7 @@ struct SileroVadModelConfig {
   // 256, 512, 768 samples for 8000 Hz
   int32_t window_size = 512;  // in samples
 
-  // If a speech segment is longer than this value, then we increase
-  // the threshold to 0.9. After finishing detecting the segment,
-  // the threshold value is reset to its original value.
-  float max_speech_duration = 20;  // in seconds
+  float max_speech_duration = 20;  // in seconds
 
   SileroVadModelConfig() = default;

diff --git a/sherpa-onnx/csrc/silero-vad-model.cc b/sherpa-onnx/csrc/silero-vad-model.cc
index 66841d56d9..2b94a8d68b 100644
--- a/sherpa-onnx/csrc/silero-vad-model.cc
+++ b/sherpa-onnx/csrc/silero-vad-model.cc
@@ -32,9 +33,13 @@ class SileroVadModel::Impl {
     }
 
     min_silence_samples_ =
-        sample_rate_ * config_.silero_vad.min_silence_duration;
+        static_cast<int32_t>(sample_rate_ *
+                             config_.silero_vad.min_silence_duration);
 
-    min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
+    min_speech_samples_ = static_cast<int32_t>(
+        sample_rate_ * config_.silero_vad.min_speech_duration);
+
+    max_speech_samples_ = static_cast<int32_t>(
+        sample_rate_ * config_.silero_vad.max_speech_duration);
   }
 
 #if __ANDROID_API__ >= 9
@@ -54,9 +59,13 @@ class SileroVadModel::Impl {
     }
 
     min_silence_samples_ =
-        sample_rate_ * config_.silero_vad.min_silence_duration;
+        static_cast<int32_t>(sample_rate_ *
+                             config_.silero_vad.min_silence_duration);
+
+    min_speech_samples_ = static_cast<int32_t>(
+        sample_rate_ * config_.silero_vad.min_speech_duration);
 
-    min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
+    max_speech_samples_ = static_cast<int32_t>(
+        sample_rate_ * config_.silero_vad.max_speech_duration);
   }
 #endif
 
@@ -155,14 +164,34 @@ class SileroVadModel::Impl {
 
   int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
 
+  int32_t MaxSpeechDurationSamples() const { return max_speech_samples_; }
+
+  float Threshold() const { return config_.silero_vad.threshold; }
+
   void SetMinSilenceDuration(float s) {
-    min_silence_samples_ = sample_rate_ * s;
+    min_silence_samples_ = static_cast<int32_t>(sample_rate_ * s);
+  }
+
+  void SetMinSpeechDuration(float s) {
+    min_speech_samples_ = static_cast<int32_t>(sample_rate_ * s);
+  }
+
+  void SetMaxSpeechDuration(float s) {
+    max_speech_samples_ = static_cast<int32_t>(sample_rate_ * s);
   }
 
   void SetThreshold(float threshold) {
     config_.silero_vad.threshold = threshold;
   }
 
+  // Returns the raw speech probability in [0, 1] for one window.
+  float Run(const float *samples, int32_t n) {
+    if (is_v5_) {
+      return RunV5(samples, n);
+    } else {
+      return RunV4(samples, n);
+    }
+  }
+
  private:
   void Init(void *model_data, size_t model_data_length) {
     sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
@@ -335,14 +364,6 @@ class SileroVadModel::Impl {
     }
   }
 
-  float Run(const float *samples, int32_t n) {
-    if (is_v5_) {
-      return RunV5(samples, n);
-    } else {
-      return RunV4(samples, n);
-    }
-  }
-
   float RunV5(const float *samples, int32_t n) {
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -418,6 +439,7 @@ class SileroVadModel::Impl {
   int64_t sample_rate_;
   int32_t min_silence_samples_;
   int32_t min_speech_samples_;
+  int32_t max_speech_samples_;
 
   bool triggered_ = false;
   int32_t current_sample_ = 0;
@@ -457,12 +479,30 @@ int32_t SileroVadModel::MinSpeechDurationSamples() const {
   return impl_->MinSpeechDurationSamples();
 }
 
+int32_t SileroVadModel::MaxSpeechDurationSamples() const {
+  return impl_->MaxSpeechDurationSamples();
+}
+
+float SileroVadModel::Threshold() const { return impl_->Threshold(); }
+
 void SileroVadModel::SetMinSilenceDuration(float s) {
   impl_->SetMinSilenceDuration(s);
 }
 
+void SileroVadModel::SetMinSpeechDuration(float s) {
+  impl_->SetMinSpeechDuration(s);
+}
+
 void SileroVadModel::SetThreshold(float threshold) {
   impl_->SetThreshold(threshold);
 }
 
+void SileroVadModel::SetMaxSpeechDuration(float s) {
+  impl_->SetMaxSpeechDuration(s);
+}
+
+float SileroVadModel::Run(const float *samples, int32_t n) {
+  return impl_->Run(samples, n);
+}
+
 }  // namespace sherpa_onnx
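`Run()` is now part of the public `SileroVadModel` API, so callers can read the raw per-window speech probability instead of the boolean `IsSpeech()`. A hypothetical caller sketch; the helper name and the loop are illustrative, only `Run()`, `WindowSize()`, and `WindowShift()` come from this patch:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

#include "sherpa-onnx/csrc/silero-vad-model.h"

// Hypothetical helper: slide a window over the samples and print the raw
// speech probability that the newly public Run() returns for each window.
void PrintSpeechProbs(sherpa_onnx::SileroVadModel *model,
                      const std::vector<float> &samples, int32_t sample_rate) {
  const int32_t window_size = model->WindowSize();
  const int32_t window_shift = model->WindowShift();
  const int32_t n = static_cast<int32_t>(samples.size());
  for (int32_t i = 0; i + window_size <= n; i += window_shift) {
    float prob = model->Run(samples.data() + i, window_size);
    std::printf("t=%.3f s  prob=%.3f\n", 1.0f * i / sample_rate, prob);
  }
}
```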
diff --git a/sherpa-onnx/csrc/silero-vad-model.h b/sherpa-onnx/csrc/silero-vad-model.h
index 169cb72440..2cddc21a77 100644
--- a/sherpa-onnx/csrc/silero-vad-model.h
+++ b/sherpa-onnx/csrc/silero-vad-model.h
@@ -37,6 +37,8 @@ class SileroVadModel : public VadModel {
    */
   bool IsSpeech(const float *samples, int32_t n) override;
 
+  // Returns the raw speech probability in [0, 1] for one window.
+  float Run(const float *samples, int32_t n);
+
   // For silero vad V4, it is WindowShift().
   // For silero vad V5, it is WindowShift()+64 for 16kHz and
   // WindowShift()+32 for 8kHz
@@ -47,9 +49,13 @@ class SileroVadModel : public VadModel {
   int32_t MinSilenceDurationSamples() const override;
   int32_t MinSpeechDurationSamples() const override;
+  int32_t MaxSpeechDurationSamples() const;
+  float Threshold() const;
 
   void SetMinSilenceDuration(float s) override;
-  void SetThreshold(float threshold) override;
+  void SetMinSpeechDuration(float s);
+  void SetMaxSpeechDuration(float s);
+  void SetThreshold(float threshold) override;
 
  private:
   class Impl;
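The rewritten detector in voice-activity-detector.cc below consumes these probabilities with a fixed hysteresis band of 0.15 under the configured threshold: probabilities inside the band never change the speaking/silence state. A minimal sketch of that per-window classification (the enum and function are mine, for illustration; the 0.15 constant is taken from the diff):

```cpp
// Three-band decision used per window by the new detector:
//   prob >= threshold                     -> speech
//   threshold - 0.15 <= prob < threshold  -> hysteresis: keep current state
//   prob <  threshold - 0.15              -> silence
enum class Band { kSpeech, kKeepState, kSilence };

Band Classify(float prob, float threshold) {
  if (prob >= threshold) return Band::kSpeech;
  if (prob >= threshold - 0.15f) return Band::kKeepState;
  return Band::kSilence;
}
```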
diff --git a/sherpa-onnx/csrc/voice-activity-detector.cc b/sherpa-onnx/csrc/voice-activity-detector.cc
index c20d3476dd..f3d5178ee2 100644
--- a/sherpa-onnx/csrc/voice-activity-detector.cc
+++ b/sherpa-onnx/csrc/voice-activity-detector.cc
@@ -9,17 +9,64 @@
 #include <utility>
 
 #include "sherpa-onnx/csrc/circular-buffer.h"
-#include "sherpa-onnx/csrc/vad-model.h"
+#include "sherpa-onnx/csrc/silero-vad-model.h"
+
+// Uncomment to print per-window speech probabilities for debugging.
+// #define __DEBUG_SPEECH_PROB___
 
 namespace sherpa_onnx {
 
+// Mirrors the timestamp type of the upstream silero-vad C++ example.
+class timestamp_t {
+ public:
+  int start;
+  int end;
+
+  // default + parameterized constructor
+  timestamp_t(int start = -1, int end = -1) : start(start), end(end) {}
+
+  // assignment operator modifies the object, therefore non-const
+  timestamp_t &operator=(const timestamp_t &a) {
+    start = a.start;
+    end = a.end;
+    return *this;
+  }
+
+  // equality comparison; does not modify the object, therefore const
+  bool operator==(const timestamp_t &a) const {
+    return (start == a.start && end == a.end);
+  }
+};
+
 class VoiceActivityDetector::Impl {
  public:
   explicit Impl(const VadModelConfig &config,
                 float buffer_size_in_seconds = 60)
-      : model_(VadModel::Create(config)),
+      : model_(std::make_unique<SileroVadModel>(config)),
         config_(config),
-        buffer_(buffer_size_in_seconds * config.sample_rate) {
-    Init();
+        buffer_(static_cast<int32_t>(buffer_size_in_seconds *
+                                     config.sample_rate)) {
+    InitDetectorParams();
+  }
+
+  // Shared by both constructors: derive all sample-count parameters from
+  // the model and the configuration.
+  void InitDetectorParams() {
+    sample_rate = config_.sample_rate;
+    int32_t sr_per_ms = sample_rate / 1000;
+    int32_t speech_pad_ms = 32;
+
+    window_size = model_->WindowSize();
+    window_shift = model_->WindowShift();
+    threshold = model_->Threshold();
+
+    min_speech_samples = model_->MinSpeechDurationSamples();
+
+    speech_pad_samples = sr_per_ms * speech_pad_ms;
+
+    // Reserve room for the window shift and the padding on both sides so
+    // that a force-split segment still fits within max_speech_duration.
+    max_speech_samples = model_->MaxSpeechDurationSamples() - window_shift -
+                         2 * speech_pad_samples;
+
+    min_silence_samples = model_->MinSilenceDurationSamples();
+
+    min_silence_samples_at_max_speech = sr_per_ms * 98;
+
+#ifdef __DEBUG_SPEECH_PROB___
+    printf(
+        "{window_size: %d, min_speech_samples:%d, max_speech_samples:%d, "
+        "min_silence_samples:%d, min_silence_samples_at_max_speech:%d}\n",
+        window_size, min_speech_samples, max_speech_samples,
+        min_silence_samples, min_silence_samples_at_max_speech);
+#endif  // __DEBUG_SPEECH_PROB___
   }
 
@@ -27,27 +74,12 @@ class VoiceActivityDetector::Impl {
 #if __ANDROID_API__ >= 9
   Impl(AAssetManager *mgr, const VadModelConfig &config,
        float buffer_size_in_seconds = 60)
-      : model_(VadModel::Create(mgr, config)),
+      : model_(std::make_unique<SileroVadModel>(mgr, config)),
         config_(config),
-        buffer_(buffer_size_in_seconds * config.sample_rate) {
-    Init();
-  }
+        buffer_(static_cast<int32_t>(buffer_size_in_seconds *
+                                     config.sample_rate)) {
+    InitDetectorParams();
+  }
 #endif
 
   void AcceptWaveform(const float *samples, int32_t n) {
-    if (buffer_.Size() > max_utterance_length_) {
-      model_->SetMinSilenceDuration(new_min_silence_duration_s_);
-      model_->SetThreshold(new_threshold_);
-    } else {
-      model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration);
-      model_->SetThreshold(config_.silero_vad.threshold);
-    }
-
-    int32_t window_size = model_->WindowSize();
-    int32_t window_shift = model_->WindowShift();
-
-    // note n is usually window_size and there is no need to use
-    // an extra buffer here
+    // Keep a copy of all incoming samples for segment extraction later.
+    buffer_.Push(samples, n);
     last_.insert(last_.end(), samples, samples + n);
 
-    if (last_.size() < window_size) {
+    if (static_cast<int32_t>(last_.size()) < window_size) {
       return;
     }
 
@@ -56,52 +88,167 @@ class VoiceActivityDetector::Impl {
     int32_t k =
         (static_cast<int32_t>(last_.size()) - window_size) / window_shift + 1;
     const float *p = last_.data();
-    bool is_speech = false;
 
     for (int32_t i = 0; i < k; ++i, p += window_shift) {
-      buffer_.Push(p, window_shift);
-      // NOTE(fangjun): Please don't use a very large n.
-      bool this_window_is_speech = model_->IsSpeech(p, window_size);
-      is_speech = is_speech || this_window_is_speech;
-    }
-
-    last_ = std::vector<float>(
-        p, static_cast<const float *>(last_.data()) + last_.size());
-
-    if (is_speech) {
-      if (start_ == -1) {
-        // beginning of speech
-        start_ = std::max(buffer_.Tail() - 2 * model_->WindowSize() -
-                              model_->MinSpeechDurationSamples(),
-                          buffer_.Head());
-      }
-    } else {
-      // non-speech
-      if (start_ != -1 && buffer_.Size()) {
-        // end of speech, save the speech segment
-        int32_t end = buffer_.Tail() - model_->MinSilenceDurationSamples();
-
-        std::vector<float> s = buffer_.Get(start_, end - start_);
-
-        SpeechSegment segment;
-
-        segment.start = start_;
-        segment.samples = std::move(s);
-
-        segments_.push(std::move(segment));
-
-        buffer_.Pop(end - buffer_.Head());
-      }
-
-      if (start_ == -1) {
-        int32_t end = buffer_.Tail() - 2 * model_->WindowSize() -
-                      model_->MinSpeechDurationSamples();
-        int32_t n = std::max(0, end - buffer_.Head());
-        if (n > 0) {
-          buffer_.Pop(n);
-        }
-      }
-
-      start_ = -1;
-    }
+      float speech_prob = model_->Run(p, window_size);
+      current_sample += window_shift;
+
+      // 1) Speech: the probability is at or above the threshold.
+      if (speech_prob >= threshold) {
+#ifdef __DEBUG_SPEECH_PROB___
+        // minus window_shift to get the precise start time point
+        float speech = current_sample - window_shift;
+        printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate,
+               speech_prob, current_sample - window_shift);
+#endif  // __DEBUG_SPEECH_PROB___
+        // Leaving a tentative silence region: drop the temporary end point.
+        if (temp_end != 0) {
+          temp_end = 0;
+          // The estimated next start precedes the last recorded end point;
+          // reset it to the current window.
+          if (next_start < prev_end) {
+            next_start = current_sample - window_shift;
+          }
+        }
+        // First speech window of a new segment: record the start point.
+        if (!triggered) {
+          triggered = true;
+          current_speech.start = current_sample - window_shift;
+        }
+        continue;
+      }
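+
+      // 2) Forced split: the current segment has exceeded
+      //    max_speech_samples. Prefer cutting at the last long-silence
+      //    candidate (prev_end) when one was seen; otherwise cut right at
+      //    the current sample.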
+ printf("{ start: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, + speech_prob, current_sample - window_shift); +#endif //__DEBUG_SPEECH_PROB___ + // Temporary end point reset + if (temp_end != 0) { + temp_end = 0; + // The next estimated start point is less than the last end point, + // reset + if (next_start < prev_end) next_start = current_sample - window_shift; + } + // First voice segmentation, record start point + if (triggered == false) { + triggered = true; + current_speech.start = current_sample - window_shift; + } + continue; } - } else { - // non-speech - if (start_ != -1 && buffer_.Size()) { - // end of speech, save the speech segment - int32_t end = buffer_.Tail() - model_->MinSilenceDurationSamples(); - - std::vector s = buffer_.Get(start_, end - start_); - SpeechSegment segment; - - segment.start = start_; - segment.samples = std::move(s); - segments_.push(std::move(segment)); - - buffer_.Pop(end - buffer_.Head()); + if ( + // If the number of samples is greater than the maximum number of + // voice fragments, forced fragmentation + (triggered == true) && + ((current_sample - current_speech.start) > max_speech_samples)) { + if (prev_end > 0) { + current_speech.end = prev_end; +#ifdef __DEBUG_SPEECH_PROB___ + printf("{>max_prev speech start: %d, end:%d}\n", current_speech.start, + current_speech.end); +#endif //__DEBUG_SPEECH_PROB___ + std::vector s = buffer_.Get( + current_speech.start, current_speech.end - current_speech.start); + SpeechSegment segment; + segment.start = current_speech.start; + segment.samples = std::move(s); + segments_.push(std::move(segment)); + current_speech = timestamp_t(); + // previously reached silence(< neg_thres) and is still not speech(< + // thres) + if (next_start < prev_end) + triggered = false; + else { + current_speech.start = next_start; + } + prev_end = 0; + next_start = 0; + temp_end = 0; + } else { + current_speech.end = current_sample; +#ifdef __DEBUG_SPEECH_PROB___ + printf("{>max speech start: %d, end:%d}\n", current_speech.start, + current_speech.end); +#endif //__DEBUG_SPEECH_PROB___ + std::vector s = buffer_.Get( + current_speech.start, current_speech.end - current_speech.start); + SpeechSegment segment; + segment.start = current_speech.start; + segment.samples = std::move(s); + segments_.push(std::move(segment)); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + continue; + } + // Chaos, stay the same + if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) { + if (triggered) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = + current_sample - window_shift; // minus window_shift to get + // precise start time point. + printf("{ speaking: %.3f s (%.3f) %08d}\n", + 1.0 * speech / sample_rate, speech_prob, + current_sample - window_shift); +#endif //__DEBUG_SPEECH_PROB___ + } else { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = + current_sample - window_shift; // minus window_shift to get + // precise start time point. 
+ printf("{ silence: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, + speech_prob, current_sample - window_shift); +#endif //__DEBUG_SPEECH_PROB___ + } + continue; } - if (start_ == -1) { - int32_t end = buffer_.Tail() - 2 * model_->WindowSize() - - model_->MinSpeechDurationSamples(); - int32_t n = std::max(0, end - buffer_.Head()); - if (n > 0) { - buffer_.Pop(n); + // 4) End + if ((speech_prob < (threshold - 0.15))) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = current_sample - window_shift - + speech_pad_samples; // minus window_shift to get precise + // start time point. + if (speech < 0.0f) { + speech = 0.0f; } + printf("{ end: %.3f s (%.3f) %08d}\n", 1.0 * speech / sample_rate, + speech_prob, current_sample - window_shift); +#endif //__DEBUG_SPEECH_PROB___ + if (triggered == true) { + //(The first silent segment after voice segmentation, recording + // possible end point) + if (temp_end == 0) { + temp_end = current_sample; + } + // (If it is greater than the maximum value of accumulated silence, + // the possible end point is recorded and used for forced segmentation + // of large audio clips.) + if (current_sample - temp_end > min_silence_samples_at_max_speech) + prev_end = temp_end; + // a. silence < min_slience_samples, continue speaking + if ((current_sample - temp_end) < min_silence_samples) { + } + // b. silence >= min_slience_samples, end speaking + else { + current_speech.end = temp_end; + if (current_speech.end - current_speech.start > + min_speech_samples) { +#ifdef __DEBUG_SPEECH_PROB___ + printf("{>min speech start: %d, end:%d}\n", current_speech.start, + current_speech.end); +#endif //__DEBUG_SPEECH_PROB___ + std::vector s = + buffer_.Get(current_speech.start, + current_speech.end - current_speech.start); + SpeechSegment segment; + segment.start = current_speech.start; + segment.samples = std::move(s); + segments_.push(std::move(segment)); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + } + } else { + // may first windows see end state. 
 
@@ -115,62 +262,60 @@ class VoiceActivityDetector::Impl {
   void Reset() {
     std::queue<SpeechSegment>().swap(segments_);
 
     model_->Reset();
     buffer_.Reset();
 
-    start_ = -1;
+    // Reset the detector state as well.
+    current_sample = 0;
+    current_speech = timestamp_t();
+    prev_end = 0;
+    next_start = 0;
+    temp_end = 0;
+    triggered = false;
+    last_.clear();
   }
 
   void Flush() {
-    if (start_ == -1 || buffer_.Size() == 0) {
-      return;
-    }
-
-    int32_t end = buffer_.Tail();
-    if (end <= start_) {
-      return;
-    }
-
-    std::vector<float> s = buffer_.Get(start_, end - start_);
-
-    SpeechSegment segment;
-
-    segment.start = start_;
-    segment.samples = std::move(s);
-
-    segments_.push(std::move(segment));
-
-    buffer_.Pop(end - buffer_.Head());
-    start_ = -1;
+    int32_t buffer_size = buffer_.Size();
+
+    if (buffer_size >= window_size) {
+      // Whatever is still buffered starts at buffer_.Head().
+      std::vector<float> s = buffer_.Get(buffer_.Head(), buffer_size);
+
+      SpeechSegment segment;
+      segment.start = buffer_.Head();
+      segment.samples = std::move(s);
+      segments_.push(std::move(segment));
+
+      buffer_.Pop(buffer_size);
+    }
   }
 
-  bool IsSpeechDetected() const { return start_ != -1; }
+  // True while the detector is inside a speech segment.
+  bool IsSpeechDetected() const { return triggered; }
 
   const VadModelConfig &GetConfig() const { return config_; }
 
  private:
-  void Init() {
-    // TODO(fangjun): Currently, we support only one vad model.
-    // If a new vad model is added, we need to change the place
-    // where max_speech_duration is placed.
-    max_utterance_length_ =
-        config_.sample_rate * config_.silero_vad.max_speech_duration;
-  }
-
- private:
   std::queue<SpeechSegment> segments_;
+  timestamp_t current_speech;
 
-  std::unique_ptr<VadModel> model_;
+  std::unique_ptr<SileroVadModel> model_;
   VadModelConfig config_;
   CircularBuffer buffer_;
   std::vector<float> last_;
 
-  int max_utterance_length_ = -1;  // in samples
-  float new_min_silence_duration_s_ = 0.1;
-  float new_threshold_ = 0.90;
-
-  int32_t start_ = -1;
+  int32_t window_size;
+  int32_t window_shift;
+  int32_t sample_rate;
+  float threshold;
+
+  int32_t min_silence_samples;                // from min_silence_duration
+  int32_t min_silence_samples_at_max_speech;  // sr_per_ms * 98
+  int32_t min_speech_samples;                 // from min_speech_duration
+  int32_t max_speech_samples;
+  int32_t speech_pad_samples;                 // in samples
+
+  // Detector state.
+  bool triggered = false;
+  unsigned int temp_end = 0;
+  // MAX: 4294967295 samples / 8 samples per ms / 1000 / 60 ≈ 8947 minutes
+  // before current_sample wraps around (at 8 kHz).
+  unsigned int current_sample = 0;
+  int32_t prev_end = 0;
+  int32_t next_start = 0;
 };
 
 VoiceActivityDetector::VoiceActivityDetector(